From 08b12279050f0128647bcf018308109a433ad546 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 21:52:21 +0800 Subject: [PATCH 1/9] refactor(query): split expr_parse into module and expose parse_expr_string Move the generated-expression tokenizer/parser from a single nodedb-query/src/expr_parse.rs into nodedb-query/src/expr_parse/ with separate tokenizer.rs and mod.rs files. Expose a new `parse_expr_string` function in nodedb-sql that delegates to sqlparser-rs so arbitrary DEFAULT/CHECK expressions can be parsed and const-folded at plan time without duplicating grammar logic. --- .../src/{expr_parse.rs => expr_parse/mod.rs} | 186 ++------------ nodedb-query/src/expr_parse/tokenizer.rs | 226 ++++++++++++++++++ nodedb-sql/src/lib.rs | 22 ++ 3 files changed, 268 insertions(+), 166 deletions(-) rename nodedb-query/src/{expr_parse.rs => expr_parse/mod.rs} (78%) create mode 100644 nodedb-query/src/expr_parse/tokenizer.rs diff --git a/nodedb-query/src/expr_parse.rs b/nodedb-query/src/expr_parse/mod.rs similarity index 78% rename from nodedb-query/src/expr_parse.rs rename to nodedb-query/src/expr_parse/mod.rs index 4c4f4762..fe3fb726 100644 --- a/nodedb-query/src/expr_parse.rs +++ b/nodedb-query/src/expr_parse/mod.rs @@ -16,8 +16,11 @@ //! //! Determinism validation: rejects `NOW()`, `RANDOM()`, `NEXTVAL()`, `UUID()`. +mod tokenizer; + use super::expr::{BinaryOp, SqlExpr}; use nodedb_types::Value; +use tokenizer::{Token, TokenKind, tokenize}; /// Parse a SQL expression string into an SqlExpr AST. /// @@ -46,160 +49,11 @@ pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec), String Ok((expr, deps)) } -// ── Tokenizer ───────────────────────────────────────────────────────── - -#[derive(Debug, Clone)] -struct Token { - text: String, - kind: TokenKind, -} - -#[derive(Debug, Clone, Copy, PartialEq)] -enum TokenKind { - Ident, - Number, - StringLit, - LParen, - RParen, - Comma, - Op, -} - -fn tokenize(input: &str) -> Result, String> { - let bytes = input.as_bytes(); - let mut tokens = Vec::new(); - let mut i = 0; - - while i < bytes.len() { - let b = bytes[i]; - - // Skip whitespace. - if b.is_ascii_whitespace() { - i += 1; - continue; - } - - // Single-char tokens. - if b == b'(' { - tokens.push(Token { - text: "(".into(), - kind: TokenKind::LParen, - }); - i += 1; - continue; - } - if b == b')' { - tokens.push(Token { - text: ")".into(), - kind: TokenKind::RParen, - }); - i += 1; - continue; - } - if b == b',' { - tokens.push(Token { - text: ",".into(), - kind: TokenKind::Comma, - }); - i += 1; - continue; - } - - // Two-char operators. - if i + 1 < bytes.len() { - let two = &input[i..i + 2]; - if matches!(two, "<=" | ">=" | "!=" | "<>") { - tokens.push(Token { - text: two.into(), - kind: TokenKind::Op, - }); - i += 2; - continue; - } - if two == "||" { - tokens.push(Token { - text: "||".into(), - kind: TokenKind::Op, - }); - i += 2; - continue; - } - } - - // Single-char operators. - if matches!(b, b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>') { - tokens.push(Token { - text: (b as char).to_string(), - kind: TokenKind::Op, - }); - i += 1; - continue; - } - - // String literal. - if b == b'\'' { - let mut s = String::new(); - i += 1; - while i < bytes.len() { - if bytes[i] == b'\'' { - if i + 1 < bytes.len() && bytes[i + 1] == b'\'' { - s.push('\''); - i += 2; - continue; - } - i += 1; - break; - } - s.push(bytes[i] as char); - i += 1; - } - tokens.push(Token { - text: s, - kind: TokenKind::StringLit, - }); - continue; - } - - // Number. 
- if b.is_ascii_digit() || (b == b'.' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit()) - { - let start = i; - while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') { - i += 1; - } - tokens.push(Token { - text: input[start..i].to_string(), - kind: TokenKind::Number, - }); - continue; - } - - // Identifier or keyword. - if b.is_ascii_alphabetic() || b == b'_' { - let start = i; - while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') { - i += 1; - } - tokens.push(Token { - text: input[start..i].to_string(), - kind: TokenKind::Ident, - }); - continue; - } - - return Err(format!("unexpected character: '{}'", b as char)); - } - - Ok(tokens) -} - // ── Recursive descent parser ────────────────────────────────────────── /// Maximum recursion depth for nested parentheses / sub-expressions. -/// Exceeding this limit returns `Err` instead of overflowing the stack. const MAX_EXPR_DEPTH: usize = 128; -/// Parse an expression (lowest precedence: OR). fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { parse_or(tokens, pos, depth) } @@ -304,13 +158,11 @@ fn parse_multiplicative( } fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { - // Unary minus. if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" { *pos += 1; let expr = parse_primary(tokens, pos, depth)?; return Ok(SqlExpr::Negate(Box::new(expr))); } - // NOT if peek_keyword(tokens, *pos, "NOT") { *pos += 1; let expr = parse_primary(tokens, pos, depth)?; @@ -327,7 +179,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result let token = &tokens[*pos]; match token.kind { - // Parenthesized expression. TokenKind::LParen => { *depth += 1; if *depth > MAX_EXPR_DEPTH { @@ -342,7 +193,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result Ok(expr) } - // Number literal. TokenKind::Number => { *pos += 1; if let Ok(i) = token.text.parse::() { @@ -354,13 +204,11 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result } } - // String literal. TokenKind::StringLit => { *pos += 1; Ok(SqlExpr::Literal(Value::String(token.text.clone()))) } - // Identifier: column ref, function call, keyword (NULL, TRUE, FALSE, CASE, COALESCE). TokenKind::Ident => { let name = token.text.clone(); let upper = name.to_uppercase(); @@ -376,7 +224,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result Ok(SqlExpr::Coalesce(args)) } _ => { - // Function call: IDENT(args). if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen { let args = parse_arg_list(tokens, pos, depth)?; Ok(SqlExpr::Function { @@ -384,7 +231,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result args, }) } else { - // Column reference. Ok(SqlExpr::Column(name.to_lowercase())) } } @@ -395,7 +241,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result } } -/// Parse `CASE WHEN cond THEN result [WHEN ... THEN ...] [ELSE result] END`. fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { let mut when_thens = Vec::new(); let mut else_expr = None; @@ -429,7 +274,6 @@ fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result Result, String> { + let chars: Vec<(usize, char)> = input.char_indices().collect(); + let mut tokens = Vec::new(); + let mut i = 0; + + while i < chars.len() { + let (_, ch) = chars[i]; + + // Skip whitespace. 
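+        // Note: only ASCII whitespace is skipped; a Unicode space that
+        // appears outside a string literal falls through to the error
+        // arm at the bottom of the loop rather than being silently eaten.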
+        if ch.is_ascii_whitespace() {
+            i += 1;
+            continue;
+        }
+
+        // Single-char structural tokens.
+        if ch == '(' {
+            tokens.push(Token {
+                text: "(".into(),
+                kind: TokenKind::LParen,
+            });
+            i += 1;
+            continue;
+        }
+        if ch == ')' {
+            tokens.push(Token {
+                text: ")".into(),
+                kind: TokenKind::RParen,
+            });
+            i += 1;
+            continue;
+        }
+        if ch == ',' {
+            tokens.push(Token {
+                text: ",".into(),
+                kind: TokenKind::Comma,
+            });
+            i += 1;
+            continue;
+        }
+
+        // Two-char operators: <=, >=, !=, <>, ||
+        if i + 1 < chars.len() {
+            let (_, next_ch) = chars[i + 1];
+            let two: String = [ch, next_ch].iter().collect();
+            if matches!(two.as_str(), "<=" | ">=" | "!=" | "<>" | "||") {
+                tokens.push(Token {
+                    text: two,
+                    kind: TokenKind::Op,
+                });
+                i += 2;
+                continue;
+            }
+        }
+
+        // Single-char operators.
+        if matches!(ch, '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>') {
+            tokens.push(Token {
+                text: ch.to_string(),
+                kind: TokenKind::Op,
+            });
+            i += 1;
+            continue;
+        }
+
+        // String literal (single-quoted).
+        if ch == '\'' {
+            let mut s = String::new();
+            i += 1;
+            while i < chars.len() {
+                let (_, c) = chars[i];
+                if c == '\'' {
+                    // Check for '' escape.
+                    if i + 1 < chars.len() && chars[i + 1].1 == '\'' {
+                        s.push('\'');
+                        i += 2;
+                        continue;
+                    }
+                    i += 1;
+                    break;
+                }
+                s.push(c);
+                i += 1;
+            }
+            tokens.push(Token {
+                text: s,
+                kind: TokenKind::StringLit,
+            });
+            continue;
+        }
+
+        // Number.
+        if ch.is_ascii_digit()
+            || (ch == '.' && i + 1 < chars.len() && chars[i + 1].1.is_ascii_digit())
+        {
+            let start_byte = chars[i].0;
+            while i < chars.len() && (chars[i].1.is_ascii_digit() || chars[i].1 == '.') {
+                i += 1;
+            }
+            let end_byte = if i < chars.len() {
+                chars[i].0
+            } else {
+                input.len()
+            };
+            tokens.push(Token {
+                text: input[start_byte..end_byte].to_string(),
+                kind: TokenKind::Number,
+            });
+            continue;
+        }
+
+        // Identifier or keyword (ASCII letters, digits, underscore).
+        if ch.is_ascii_alphabetic() || ch == '_' {
+            let start_byte = chars[i].0;
+            while i < chars.len() && (chars[i].1.is_ascii_alphanumeric() || chars[i].1 == '_') {
+                i += 1;
+            }
+            let end_byte = if i < chars.len() {
+                chars[i].0
+            } else {
+                input.len()
+            };
+            tokens.push(Token {
+                text: input[start_byte..end_byte].to_string(),
+                kind: TokenKind::Ident,
+            });
+            continue;
+        }
+
+        // Any other character — including stray non-ASCII outside a string
+        // literal — is rejected. The old byte-indexed tokenizer could panic
+        // here; now we report the character and its code point instead.
+        return Err(format!(
+            "unexpected character: '{ch}' (U+{:04X})",
+            ch as u32
+        ));
+    }
+
+    Ok(tokens)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn ascii_expression() {
+        let tokens = tokenize("price * (1 + tax_rate)").unwrap();
+        assert_eq!(tokens.len(), 7);
+    }
+
+    #[test]
+    fn cjk_string_literal() {
+        let tokens = tokenize("'你好' || name").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[0].kind, TokenKind::StringLit);
+        assert_eq!(tokens[0].text, "你好");
+    }
+
+    #[test]
+    fn emoji_string_literal() {
+        let tokens = tokenize("'🎉' || tag").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[0].text, "🎉");
+    }
+
+    #[test]
+    fn two_char_op_after_multibyte_string() {
+        // Previously panicked: &input[i..i+2] crossed a char boundary.
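+        // The char-based scanner builds the operator window from `chars`,
+        // never by slicing raw bytes, so the multibyte quote preceding
+        // `||` can no longer split a UTF-8 sequence.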
+ let tokens = tokenize("'你' || x").unwrap(); + assert_eq!(tokens.len(), 3); + assert_eq!(tokens[1].text, "||"); + } + + #[test] + fn escaped_quote_in_string() { + let tokens = tokenize("'it''s'").unwrap(); + assert_eq!(tokens.len(), 1); + assert_eq!(tokens[0].text, "it's"); + } + + #[test] + fn latin_diacritics_in_string() { + let tokens = tokenize("'café'").unwrap(); + assert_eq!(tokens[0].text, "café"); + } + + #[test] + fn comparison_after_cjk() { + let tokens = tokenize("name != '禁止'").unwrap(); + assert_eq!(tokens.len(), 3); + assert_eq!(tokens[1].text, "!="); + assert_eq!(tokens[2].text, "禁止"); + } +} diff --git a/nodedb-sql/src/lib.rs b/nodedb-sql/src/lib.rs index 4b614fe2..8ff3be90 100644 --- a/nodedb-sql/src/lib.rs +++ b/nodedb-sql/src/lib.rs @@ -25,6 +25,28 @@ pub use error::{Result, SqlError}; pub use params::ParamValue; pub use types::*; +/// Parse a standalone SQL expression string into an `SqlExpr`. +/// +/// Used by the DEFAULT expression evaluator to handle arbitrary expressions +/// (e.g. `upper('x')`, `1 + 2`) that don't match the hard-coded keyword list. +pub fn parse_expr_string(expr_text: &str) -> Result { + use sqlparser::dialect::GenericDialect; + use sqlparser::parser::Parser; + + let dialect = GenericDialect {}; + let ast_expr = Parser::new(&dialect) + .try_with_sql(expr_text) + .map_err(|e| SqlError::Parse { + detail: e.to_string(), + })? + .parse_expr() + .map_err(|e| SqlError::Parse { + detail: e.to_string(), + })?; + + resolver::expr::convert_expr(&ast_expr) +} + use functions::registry::FunctionRegistry; use parser::preprocess; use parser::statement::{StatementKind, classify, parse_sql}; From 58071afb74df4435110f7fd9e0b4e00dfaf27b86 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 21:52:37 +0800 Subject: [PATCH 2/9] feat(recursive-cte): implement join-link based working-table iteration Add a join_link field (collection_field, working_table_field) to RecursiveScan that drives proper tree-traversal CTEs. Each iteration now builds a hash-set of values from the frontier's working_field and finds collection rows whose collection_field is in that set, matching the SQL INNER JOIN ON semantics of standard recursive CTEs. Previously the recursive step applied filters to the full collection without any join relationship to the previous iteration, producing incorrect results for parent/child tree queries. Also handle strict (Binary Tuple) encoded collections in the recursive executor by converting through the schema before filter evaluation. --- nodedb-sql/src/planner/cte.rs | 236 +++++++++++++++-- nodedb-sql/src/types.rs | 5 + nodedb/src/bridge/physical_plan/query.rs | 5 + .../planner/sql_plan_convert/convert.rs | 9 +- .../control/planner/sql_plan_convert/scan.rs | 69 +++-- .../planner/sql_plan_convert/scan_params.rs | 13 + .../native/dispatch/plan_builder/query.rs | 1 + nodedb/src/data/executor/dispatch/other.rs | 2 + .../src/data/executor/handlers/recursive.rs | 241 ++++++++++++++---- 9 files changed, 482 insertions(+), 99 deletions(-) diff --git a/nodedb-sql/src/planner/cte.rs b/nodedb-sql/src/planner/cte.rs index 06a994da..77ce72b3 100644 --- a/nodedb-sql/src/planner/cte.rs +++ b/nodedb-sql/src/planner/cte.rs @@ -4,9 +4,15 @@ use sqlparser::ast::{self, Query, SetExpr}; use crate::error::{Result, SqlError}; use crate::functions::registry::FunctionRegistry; +use crate::parser::normalize::{normalize_ident, normalize_object_name}; use crate::types::*; /// Plan a WITH RECURSIVE query. 
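+///
+/// e.g. `WITH RECURSIVE sub AS (SELECT * FROM tree WHERE parent_id IS NULL
+/// UNION ALL SELECT t.* FROM tree t INNER JOIN sub s ON t.parent_id = s.id)
+/// SELECT * FROM sub` — the base branch seeds the working table and the
+/// JOIN's ON link drives each subsequent iteration.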
+/// +/// Supports table-based recursive CTEs where the base case scans a real +/// collection and the recursive step references both the collection and +/// the CTE. Value-generating CTEs (no underlying collection) return an +/// explicit unsupported error. pub fn plan_recursive_cte( query: &Query, catalog: &dyn SqlCatalog, @@ -16,41 +22,227 @@ pub fn plan_recursive_cte( detail: "expected WITH clause".into(), })?; - // Get the CTE definition. let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse { detail: "empty WITH clause".into(), })?; + let cte_name = normalize_ident(&cte.alias.name); + let cte_query = &cte.query; // The CTE body should be a UNION of base case and recursive case. - match &*cte_query.body { + let (left, right, set_quantifier) = match &*cte_query.body { SetExpr::SetOperation { op: ast::SetOperator::Union, left, right, - .. + set_quantifier, + } => (left, right, set_quantifier), + _ => { + return Err(SqlError::Unsupported { + detail: "WITH RECURSIVE requires UNION in CTE body".into(), + }); + } + }; + + // UNION ALL → distinct = false; UNION → distinct = true. + let distinct = !matches!(set_quantifier, ast::SetQuantifier::All); + + // Plan the base case (should not reference the CTE name). + let base = plan_cte_branch(left, catalog, functions)?; + + // Extract the source collection from the base case. + let collection = extract_collection(&base).unwrap_or_default(); + + // Plan the recursive branch. The recursive branch references the CTE + // name in its FROM clause — either directly (value-gen) or via a JOIN + // with a real table. We attempt to plan it; if it fails because the + // CTE name isn't in the catalog, we try to extract the real table from + // a JOIN and use it with the CTE self-reference as the recursive filter. + let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions) { + Ok(plan) => (extract_filters(&plan), None), + Err(_) => { + // The recursive branch references the CTE name. Try to extract + // the real collection, filters, and join link from the AST. + extract_recursive_info(right, &cte_name)? + } + }; + + if collection.is_empty() { + return Err(SqlError::Unsupported { + detail: "WITH RECURSIVE requires a base case that scans a collection; \ + value-generating recursive CTEs are not yet supported" + .into(), + }); + } + + Ok(SqlPlan::RecursiveScan { + collection, + base_filters: extract_filters(&base), + recursive_filters, + join_link, + max_iterations: 100, + distinct, + limit: 10000, + }) +} + +/// Extract recursive info from the AST when normal planning fails +/// because the FROM clause references the CTE name. +/// +/// Returns `(filters, join_link)` where `join_link` is the +/// `(collection_field, working_table_field)` pair for the working-table +/// hash-join. +/// +/// Handles the common tree-traversal pattern: +/// `SELECT t.id FROM tree t INNER JOIN cte_name d ON t.parent_id = d.id` +/// → join_link = `("parent_id", "id")` +/// `(filters, join_link)` where `join_link` is `(collection_field, working_table_field)`. 
+type RecursiveInfo = (Vec, Option<(String, String)>); + +fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result { + let select = match expr { + SetExpr::Select(s) => s, + _ => { + return Err(SqlError::Unsupported { + detail: "recursive CTE branch must be SELECT".into(), + }); + } + }; + + let mut real_table_alias = None; + let mut cte_alias = None; + let mut join_on_expr = None; + + for from in &select.from { + let table_name = extract_table_name(&from.relation); + let table_alias = extract_table_alias(&from.relation); + + if let Some(name) = &table_name { + if name.eq_ignore_ascii_case(cte_name) { + cte_alias = table_alias.or_else(|| Some(name.clone())); + } else { + real_table_alias = table_alias.or_else(|| Some(name.clone())); + } + } + + for join in &from.joins { + let join_table = extract_table_name(&join.relation); + let join_alias = extract_table_alias(&join.relation); + if let Some(jt) = &join_table { + if jt.eq_ignore_ascii_case(cte_name) { + cte_alias = join_alias.or_else(|| Some(jt.clone())); + if let Some(cond) = extract_join_on_condition(&join.join_operator) { + join_on_expr = Some(cond.clone()); + } + } else { + real_table_alias = join_alias.or_else(|| Some(jt.clone())); + if join_on_expr.is_none() + && let Some(cond) = extract_join_on_condition(&join.join_operator) + { + join_on_expr = Some(cond.clone()); + } + } + } + } + } + + // Extract the join link from the ON condition. + let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) = + (&real_table_alias, &cte_alias, &join_on_expr) + { + extract_equi_link(on_expr, real_alias, cte_al) + } else { + None + }; + + // Convert the WHERE clause to filters if present. + let mut filters = Vec::new(); + if let Some(where_expr) = &select.selection { + let converted = crate::resolver::expr::convert_expr(where_expr)?; + filters.push(Filter { + expr: FilterExpr::Expr(converted), + }); + } + + Ok((filters, join_link)) +} + +/// Extract `(collection_field, cte_field)` from an equi-join ON clause. +/// +/// Given `t.parent_id = d.id` where `t` is the real table alias and `d` +/// is the CTE alias, returns `("parent_id", "id")`. +fn extract_equi_link( + expr: &ast::Expr, + real_alias: &str, + cte_alias: &str, +) -> Option<(String, String)> { + match expr { + ast::Expr::BinaryOp { + left, + op: ast::BinaryOperator::Eq, + right, } => { - let base = plan_cte_branch(left, catalog, functions)?; - let recursive = plan_cte_branch(right, catalog, functions)?; - - // Extract collection and filters from base/recursive plans. - let collection = extract_collection(&base) - .or_else(|| extract_collection(&recursive)) - .unwrap_or_default(); - - Ok(SqlPlan::RecursiveScan { - collection, - base_filters: extract_filters(&base), - recursive_filters: extract_filters(&recursive), - max_iterations: 100, - distinct: true, - limit: 10000, - }) + let left_parts = extract_qualified_column(left)?; + let right_parts = extract_qualified_column(right)?; + + // Determine which side is the real table and which is the CTE. + if left_parts.0.eq_ignore_ascii_case(real_alias) + && right_parts.0.eq_ignore_ascii_case(cte_alias) + { + Some((left_parts.1, right_parts.1)) + } else if right_parts.0.eq_ignore_ascii_case(real_alias) + && left_parts.0.eq_ignore_ascii_case(cte_alias) + { + Some((right_parts.1, left_parts.1)) + } else { + None + } } - _ => Err(SqlError::Unsupported { - detail: "WITH RECURSIVE requires UNION in CTE body".into(), - }), + // For AND-combined conditions, take the first equi-link found. 
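+        // e.g. `ON t.parent_id = d.id AND t.active = true` — recurse into
+        // the left conjunct first, then the right, and yield
+        // ("parent_id", "id") from whichever side contains the link.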
+ ast::Expr::BinaryOp { + left, + op: ast::BinaryOperator::And, + right, + } => extract_equi_link(left, real_alias, cte_alias) + .or_else(|| extract_equi_link(right, real_alias, cte_alias)), + _ => None, + } +} + +/// Extract `(table_or_alias, column)` from a qualified column reference. +fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> { + match expr { + ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => { + Some((normalize_ident(&parts[0]), normalize_ident(&parts[1]))) + } + _ => None, + } +} + +fn extract_table_name(relation: &ast::TableFactor) -> Option { + match relation { + ast::TableFactor::Table { name, .. } => Some(normalize_object_name(name)), + _ => None, + } +} + +fn extract_table_alias(relation: &ast::TableFactor) -> Option { + match relation { + ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)), + _ => None, + } +} + +fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> { + use ast::JoinOperator::*; + let constraint = match op { + Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c, + _ => return None, + }; + match constraint { + ast::JoinConstraint::On(expr) => Some(expr), + _ => None, } } diff --git a/nodedb-sql/src/types.rs b/nodedb-sql/src/types.rs index 9718275f..31dd4367 100644 --- a/nodedb-sql/src/types.rs +++ b/nodedb-sql/src/types.rs @@ -197,6 +197,11 @@ pub enum SqlPlan { collection: String, base_filters: Vec, recursive_filters: Vec, + /// Equi-join link for tree-traversal recursion: + /// `(collection_field, working_table_field)`. + /// e.g. `("parent_id", "id")` means each iteration finds rows + /// where `collection.parent_id` matches a `working_table.id`. + join_link: Option<(String, String)>, max_iterations: usize, distinct: bool, limit: usize, diff --git a/nodedb/src/bridge/physical_plan/query.rs b/nodedb/src/bridge/physical_plan/query.rs index 1a5122aa..7eb18ab4 100644 --- a/nodedb/src/bridge/physical_plan/query.rs +++ b/nodedb/src/bridge/physical_plan/query.rs @@ -187,6 +187,11 @@ pub enum QueryOp { base_filters: Vec, /// Recursive step filters (applied to working table each iteration). recursive_filters: Vec, + /// Equi-join link for tree-traversal recursion: + /// `(collection_field, working_table_field)`. + /// Each iteration finds rows where `collection_field` value + /// matches a `working_table_field` value from the previous iteration. + join_link: Option<(String, String)>, /// Maximum iterations to prevent infinite loops. Default: 100. max_iterations: usize, /// Whether to deduplicate results (UNION vs UNION ALL). diff --git a/nodedb/src/control/planner/sql_plan_convert/convert.rs b/nodedb/src/control/planner/sql_plan_convert/convert.rs index 028c6729..48259035 100644 --- a/nodedb/src/control/planner/sql_plan_convert/convert.rs +++ b/nodedb/src/control/planner/sql_plan_convert/convert.rs @@ -134,15 +134,16 @@ pub(super) fn convert_one( right, on, join_type, + condition, limit, projection, filters, - .. 
} => super::scan::convert_join(super::scan_params::JoinPlanParams { left, right, on, join_type, + condition, limit, projection, filters, @@ -275,18 +276,20 @@ pub(super) fn convert_one( collection, base_filters, recursive_filters, + join_link, max_iterations, distinct, limit, - } => super::scan::convert_recursive_scan( + } => super::scan::convert_recursive_scan(super::scan_params::RecursiveScanParams { collection, base_filters, recursive_filters, + join_link, max_iterations, distinct, limit, tenant_id, - ), + }), SqlPlan::Cte { definitions, outer } => { super::set_ops::convert_cte(definitions, outer, tenant_id, ctx) diff --git a/nodedb/src/control/planner/sql_plan_convert/scan.rs b/nodedb/src/control/planner/sql_plan_convert/scan.rs index 3d596fa5..c5ab8701 100644 --- a/nodedb/src/control/planner/sql_plan_convert/scan.rs +++ b/nodedb/src/control/planner/sql_plan_convert/scan.rs @@ -13,15 +13,49 @@ use super::aggregate::{ }; use super::convert::convert_one; use super::expr::convert_sort_keys; -use super::filter::serialize_filters; +use super::filter::{expr_filter_qualified, serialize_filters}; use super::scan_params::{ - HybridSearchParams, JoinPlanParams, ScanParams, SpatialScanParams, TimeseriesScanParams, + HybridSearchParams, JoinPlanParams, RecursiveScanParams, ScanParams, SpatialScanParams, + TimeseriesScanParams, }; use super::value::{ extract_time_range, row_to_msgpack, sql_value_to_bytes, sql_value_to_string, write_msgpack_array_header, }; +/// Serialize WHERE filters + non-equi join condition into a single `Vec`. +/// +/// The non-equi condition (from the ON clause) is appended as a +/// `FilterOp::Expr` ScanFilter so the join executor evaluates it on +/// merged rows alongside any post-join WHERE filters. +fn serialize_join_filters( + filters: &[Filter], + condition: &Option, +) -> crate::Result> { + match condition { + None => serialize_filters(filters), + Some(cond) => { + // Deserialize existing filters (if any), append condition, re-serialize. 
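+            // The serialized form is an opaque msgpack Vec, so there is no
+            // in-place append: decode the WHERE filters, push the ON
+            // condition as one more ScanFilter, and re-encode once.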
+ let mut scan_filters: Vec = + if !filters.is_empty() { + let base = serialize_filters(filters)?; + if base.is_empty() { + Vec::new() + } else { + zerompk::from_msgpack(&base).unwrap_or_default() + } + } else { + Vec::new() + }; + scan_filters.push(expr_filter_qualified(cond)); + zerompk::to_msgpack_vec(&scan_filters).map_err(|e| crate::Error::Serialization { + format: "msgpack".into(), + detail: format!("join filter serialization: {e}"), + }) + } + } +} + pub(super) fn convert_scan(p: ScanParams<'_>) -> crate::Result> { let ScanParams { collection, @@ -153,6 +187,7 @@ pub(super) fn convert_join(p: JoinPlanParams<'_>) -> crate::Result) -> crate::Result) -> crate::Result) -> crate::Result, ) -> crate::Result> { - let vshard = VShardId::from_collection(collection); + let vshard = VShardId::from_collection(p.collection); Ok(vec![PhysicalTask { - tenant_id, + tenant_id: p.tenant_id, vshard_id: vshard, plan: PhysicalPlan::Query(QueryOp::RecursiveScan { - collection: collection.into(), - base_filters: serialize_filters(base_filters)?, - recursive_filters: serialize_filters(recursive_filters)?, - max_iterations: *max_iterations, - distinct: *distinct, - limit: *limit, + collection: p.collection.into(), + base_filters: serialize_filters(p.base_filters)?, + recursive_filters: serialize_filters(p.recursive_filters)?, + join_link: p.join_link.clone(), + max_iterations: *p.max_iterations, + distinct: *p.distinct, + limit: *p.limit, }), post_set_op: PostSetOp::None, }]) diff --git a/nodedb/src/control/planner/sql_plan_convert/scan_params.rs b/nodedb/src/control/planner/sql_plan_convert/scan_params.rs index adc1ad16..1ad4adea 100644 --- a/nodedb/src/control/planner/sql_plan_convert/scan_params.rs +++ b/nodedb/src/control/planner/sql_plan_convert/scan_params.rs @@ -26,6 +26,7 @@ pub(super) struct JoinPlanParams<'a> { pub right: &'a SqlPlan, pub on: &'a [(String, String)], pub join_type: &'a nodedb_sql::types::JoinType, + pub condition: &'a Option, pub limit: &'a usize, pub projection: &'a [Projection], pub filters: &'a [Filter], @@ -33,6 +34,18 @@ pub(super) struct JoinPlanParams<'a> { pub ctx: &'a ConvertContext, } +/// Parameters for `convert_recursive_scan`. +pub(super) struct RecursiveScanParams<'a> { + pub collection: &'a str, + pub base_filters: &'a [Filter], + pub recursive_filters: &'a [Filter], + pub join_link: &'a Option<(String, String)>, + pub max_iterations: &'a usize, + pub distinct: &'a bool, + pub limit: &'a usize, + pub tenant_id: TenantId, +} + /// Parameters for `convert_timeseries_scan`. 
pub(super) struct TimeseriesScanParams<'a> { pub collection: &'a str, diff --git a/nodedb/src/control/server/native/dispatch/plan_builder/query.rs b/nodedb/src/control/server/native/dispatch/plan_builder/query.rs index 00259c83..9f767c59 100644 --- a/nodedb/src/control/server/native/dispatch/plan_builder/query.rs +++ b/nodedb/src/control/server/native/dispatch/plan_builder/query.rs @@ -16,6 +16,7 @@ pub(crate) fn build_recursive_scan( collection: collection.to_string(), base_filters, recursive_filters: Vec::new(), + join_link: None, max_iterations: 100, distinct: true, limit, diff --git a/nodedb/src/data/executor/dispatch/other.rs b/nodedb/src/data/executor/dispatch/other.rs index 1e1ca617..5adc2c77 100644 --- a/nodedb/src/data/executor/dispatch/other.rs +++ b/nodedb/src/data/executor/dispatch/other.rs @@ -139,6 +139,7 @@ impl CoreLoop { collection, base_filters, recursive_filters, + join_link, max_iterations, distinct, limit, @@ -148,6 +149,7 @@ impl CoreLoop { collection, base_filters, recursive_filters, + join_link.as_ref(), *max_iterations, *distinct, *limit, diff --git a/nodedb/src/data/executor/handlers/recursive.rs b/nodedb/src/data/executor/handlers/recursive.rs index 6fc48579..6c2fd1b9 100644 --- a/nodedb/src/data/executor/handlers/recursive.rs +++ b/nodedb/src/data/executor/handlers/recursive.rs @@ -1,8 +1,9 @@ //! Recursive CTE handler: iterative fixed-point execution. //! //! Executes the base query once to seed the working table, then -//! repeatedly executes the recursive query until no new rows are -//! produced (fixed point) or max_iterations is reached. +//! repeatedly joins the collection against the working table via +//! the `join_link` until no new rows are produced (fixed point) +//! or `max_iterations` is reached. use std::collections::HashSet; @@ -16,10 +17,11 @@ impl CoreLoop { /// /// Algorithm: /// 1. Seed: scan collection with base_filters → working_table - /// 2. Loop: scan collection with recursive_filters, join against working_table - /// 3. Add new rows to working_table - /// 4. Repeat until no new rows or max_iterations reached - /// 5. Return all accumulated rows + /// 2. Loop: for each row in collection, check if `join_link.0` value + /// matches any `join_link.1` value in the working_table → new matches + /// 3. New matches become the working table for the next iteration + /// 4. Accumulate all results + /// 5. Repeat until no new rows or max_iterations reached #[allow(clippy::too_many_arguments)] pub(in crate::data::executor) fn execute_recursive_scan( &mut self, @@ -28,6 +30,7 @@ impl CoreLoop { collection: &str, base_filters: &[u8], recursive_filters: &[u8], + join_link: Option<&(String, String)>, max_iterations: usize, distinct: bool, limit: usize, @@ -66,6 +69,17 @@ impl CoreLoop { } }; + // Check if the collection uses strict (Binary Tuple) encoding. + let config_key = format!("{tid}:{collection}"); + let strict_schema = self.doc_configs.get(&config_key).and_then(|c| { + if let crate::bridge::physical_plan::StorageMode::Strict { ref schema } = c.storage_mode + { + Some(schema.clone()) + } else { + None + } + }); + // Scan all documents once (used for both base and recursive steps). let all_docs = match self.sparse.scan_documents(tid, collection, scan_limit) { Ok(d) => d, @@ -79,15 +93,40 @@ impl CoreLoop { } }; - // Step 1: Seed working table with base query results (raw msgpack). + // Convert raw bytes to msgpack. For strict docs, this requires the schema. 
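+        // Strict collections store rows as Binary Tuples rather than JSON,
+        // so each row must be decoded through its schema to produce the
+        // msgpack shape the filter predicates evaluate against.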
+        let to_msgpack = |value: &[u8]| -> Option<Vec<u8>> {
+            if let Some(ref schema) = strict_schema {
+                super::super::strict_format::binary_tuple_to_msgpack(value, schema)
+            } else {
+                Some(super::super::doc_format::json_to_msgpack(value))
+            }
+        };
+
+        // Step 1: Seed working table with base query results.
         let mut results: Vec<Vec<u8>> = Vec::new();
         let mut seen_keys: HashSet<String> = HashSet::new();
 
+        tracing::debug!(
+            core = self.core_id,
+            %collection,
+            all_docs = all_docs.len(),
+            base_preds = base_preds.len(),
+            strict = strict_schema.is_some(),
+            ?join_link,
+            "recursive CTE: starting seed"
+        );
+
         for (_doc_id, value) in &all_docs {
-            if !base_preds.iter().all(|f| f.matches_binary(value)) {
+            let mp = match to_msgpack(value) {
+                Some(m) => m,
+                None => {
+                    tracing::debug!(core = self.core_id, "to_msgpack returned None");
+                    continue;
+                }
+            };
+            if !base_preds.iter().all(|f| f.matches_binary(&mp)) {
                 continue;
             }
-            let mp = super::super::doc_format::json_to_msgpack(value);
             let key = if distinct {
                 nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default()
             } else {
@@ -98,55 +137,135 @@ impl CoreLoop {
             }
         }
 
+        tracing::debug!(
+            core = self.core_id,
+            seed_count = results.len(),
+            "recursive CTE: seed complete"
+        );
+
         // Step 2: Iterate recursive step until fixed point.
-        let mut prev_count = 0;
-        for iteration in 0..max_iterations {
-            if results.len() >= limit || results.len() == prev_count {
-                break;
-            }
-            prev_count = results.len();
-
-            // The "working table" for this iteration is the rows added in the
-            // previous iteration. For the recursive step, we scan the collection
-            // and filter by recursive predicates AND check that the row relates
-            // to existing working table rows (by matching any field).
-            //
-            // In a full SQL recursive CTE, the recursive term references the
-            // CTE name. Here, we approximate by applying recursive filters to
-            // the full collection and adding any new matching rows.
-            let mut new_rows = Vec::new();
-            for (doc_id, value) in &all_docs {
-                if results.len() + new_rows.len() >= limit {
+        if let Some((collection_field, working_field)) = join_link {
+            // Working-table hash-join: each iteration finds collection rows
+            // where `collection_field` matches a `working_field` value from
+            // the previous iteration's new rows.
+            let mut frontier = results.clone();
+
+            for iteration in 0..max_iterations {
+                if results.len() >= limit || frontier.is_empty() {
                     break;
                 }
-                if !recursive_preds.iter().all(|f| f.matches_binary(value)) {
-                    continue;
+
+                // Build hash set of working_field values from the frontier.
+                let frontier_values: HashSet<String> = frontier
+                    .iter()
+                    .filter_map(|row| extract_field_string(row, working_field))
+                    .collect();
+
+                if frontier_values.is_empty() {
+                    break;
                 }
-                let mp = super::super::doc_format::json_to_msgpack(value);
-                let key = if distinct {
-                    nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default()
-                } else {
-                    doc_id.clone()
-                };
-                if !distinct || seen_keys.insert(key) {
-                    new_rows.push(mp);
+
+                let mut new_rows = Vec::new();
+                for (_doc_id, value) in &all_docs {
+                    if results.len() + new_rows.len() >= limit {
+                        break;
+                    }
+
+                    let mp = match to_msgpack(value) {
+                        Some(m) => m,
+                        None => continue,
+                    };
+
+                    // Apply recursive filters (WHERE clause from recursive branch).
+                    if !recursive_preds.iter().all(|f| f.matches_binary(&mp)) {
+                        continue;
+                    }
+
+                    // Check join link: collection_field value must be in frontier.
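+                    // Matching compares stringified values (see
+                    // extract_field_string), so an integer id and a string
+                    // id only join if both render to the same text.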
+ let field_val = match extract_field_string(&mp, collection_field) { + Some(v) => v, + None => continue, + }; + if !frontier_values.contains(&field_val) { + continue; + } + + let key = if distinct { + nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() + } else { + String::new() + }; + if !distinct || seen_keys.insert(key) { + new_rows.push(mp); + } + } + + if new_rows.is_empty() { + break; } - } - if new_rows.is_empty() { - break; + tracing::debug!( + core = self.core_id, + iteration, + new_rows = new_rows.len(), + total = results.len() + new_rows.len(), + "recursive CTE iteration (join-link)" + ); + frontier = new_rows.clone(); + results.extend(new_rows); } + } else { + // No join link — fall back to filter-only iteration (original behavior). + let mut prev_count = 0; + for iteration in 0..max_iterations { + if results.len() >= limit || results.len() == prev_count { + break; + } + prev_count = results.len(); + + let mut new_rows = Vec::new(); + for (doc_id, value) in &all_docs { + if results.len() + new_rows.len() >= limit { + break; + } + let mp = match to_msgpack(value) { + Some(m) => m, + None => continue, + }; + if !recursive_preds.iter().all(|f| f.matches_binary(&mp)) { + continue; + } + let key = if distinct { + nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() + } else { + doc_id.clone() + }; + if !distinct || seen_keys.insert(key) { + new_rows.push(mp); + } + } - tracing::debug!( - core = self.core_id, - iteration, - new_rows = new_rows.len(), - total = results.len() + new_rows.len(), - "recursive CTE iteration" - ); - results.extend(new_rows); + if new_rows.is_empty() { + break; + } + + tracing::debug!( + core = self.core_id, + iteration, + new_rows = new_rows.len(), + total = results.len() + new_rows.len(), + "recursive CTE iteration (filter-only)" + ); + results.extend(new_rows); + } } + tracing::debug!( + core = self.core_id, + total = results.len(), + "recursive CTE: iteration complete" + ); + // Truncate to limit. results.truncate(limit); @@ -156,14 +275,24 @@ impl CoreLoop { for row in &results { payload.extend_from_slice(row); } - match Ok::, crate::Error>(payload) { - Ok(payload) => self.response_with_payload(task, payload), - Err(e) => self.response_error( - task, - ErrorCode::Internal { - detail: e.to_string(), - }, - ), + self.response_with_payload(task, payload) + } +} + +/// Extract a field value from a msgpack document as a string for hash lookup. +fn extract_field_string(msgpack_doc: &[u8], field_name: &str) -> Option { + let value = nodedb_types::value_from_msgpack(msgpack_doc).ok()?; + match &value { + nodedb_types::Value::Object(map) => { + let v = map.get(field_name)?; + match v { + nodedb_types::Value::String(s) => Some(s.clone()), + nodedb_types::Value::Integer(i) => Some(i.to_string()), + nodedb_types::Value::Float(f) => Some(f.to_string()), + nodedb_types::Value::Bool(b) => Some(b.to_string()), + _ => Some(format!("{v:?}")), + } } + _ => None, } } From 13a7cde1c9202344fde5b1ed527317eb91e2b477 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 21:52:53 +0800 Subject: [PATCH 3/9] fix(planner): correct join non-equi predicates and unsupported join types Fold all non-equi ON predicates with AND instead of silently dropping all but the first, which caused queries with compound ON clauses to produce wrong results. Return explicit errors for NATURAL JOIN and implicit cross-joins (no ON/USING clause) rather than silently succeeding with empty join keys. 
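A minimal sketch of the fold over sqlparser's `ast::Expr` (names
illustrative, not the exact diff):

    // Keep every leftover non-equi ON predicate as one AND tree
    // instead of taking only the first with `.next()`.
    let condition = non_equi_predicates
        .into_iter()
        .reduce(|acc, next| ast::Expr::BinaryOp {
            left: Box::new(acc),
            op: ast::BinaryOperator::And,
            right: Box::new(next),
        });

With `reduce`, `ON a.id = b.id AND a.lo < b.hi AND a.tag <> b.tag`
keeps both non-equi conjuncts instead of silently dropping one.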
Add sql_expr_to_bridge_expr_qualified and expr_filter_qualified for join contexts where merged documents use table-qualified field names, and wire the join condition through serialize_join_filters so non-equi ON predicates are evaluated against merged rows alongside WHERE filters. --- nodedb-sql/src/planner/join.rs | 19 +++++- .../control/planner/sql_plan_convert/expr.rs | 65 ++++++++++++++----- .../planner/sql_plan_convert/filter.rs | 14 +++- 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/nodedb-sql/src/planner/join.rs b/nodedb-sql/src/planner/join.rs index a0f46a3a..225de587 100644 --- a/nodedb-sql/src/planner/join.rs +++ b/nodedb-sql/src/planner/join.rs @@ -204,7 +204,16 @@ fn extract_join_constraint(constraint: &ast::JoinConstraint) -> Result Result Ok((Vec::new(), None)), - ast::JoinConstraint::None => Ok((Vec::new(), None)), + ast::JoinConstraint::Natural => Err(crate::error::SqlError::Unsupported { + detail: "NATURAL JOIN is not supported; use explicit ON or USING clause".into(), + }), + ast::JoinConstraint::None => Err(crate::error::SqlError::Unsupported { + detail: "implicit cross join (no ON/USING clause) is not supported".into(), + }), } } diff --git a/nodedb/src/control/planner/sql_plan_convert/expr.rs b/nodedb/src/control/planner/sql_plan_convert/expr.rs index b2567f6e..d7e85877 100644 --- a/nodedb/src/control/planner/sql_plan_convert/expr.rs +++ b/nodedb/src/control/planner/sql_plan_convert/expr.rs @@ -6,13 +6,36 @@ use super::value::sql_value_to_nodedb_value; /// Convert a `nodedb_sql::types::SqlExpr` (parser AST) to a /// `nodedb_query::expr::SqlExpr` (bridge evaluation type). +/// +/// Column references use the **bare** name (no table qualifier) for +/// single-collection evaluation contexts (WHERE, CHECK, GENERATED). +/// For join contexts where the merged document uses qualified keys +/// (`"t1.col"`), use [`sql_expr_to_bridge_expr_qualified`] instead. pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr { + convert_expr_inner(expr, false) +} + +/// Like [`sql_expr_to_bridge_expr`] but qualifies column references +/// with their table name (`t.col` → `"t.col"`) for join merged docs. +pub(super) fn sql_expr_to_bridge_expr_qualified( + expr: &SqlExpr, +) -> crate::bridge::expr_eval::SqlExpr { + convert_expr_inner(expr, true) +} + +fn convert_expr_inner(expr: &SqlExpr, qualify: bool) -> crate::bridge::expr_eval::SqlExpr { use crate::bridge::expr_eval::SqlExpr as BExpr; match expr { - SqlExpr::Column { name, .. } => BExpr::Column(name.clone()), + SqlExpr::Column { table, name } => { + if qualify { + BExpr::Column(nodedb_sql::planner::qualified_name(table.as_deref(), name)) + } else { + BExpr::Column(name.clone()) + } + } SqlExpr::Literal(v) => BExpr::Literal(sql_value_to_nodedb_value(v)), SqlExpr::BinaryOp { left, op, right } => BExpr::BinaryOp { - left: Box::new(sql_expr_to_bridge_expr(left)), + left: Box::new(convert_expr_inner(left, qualify)), op: match op { nodedb_sql::types::BinaryOp::Add => crate::bridge::expr_eval::BinaryOp::Add, nodedb_sql::types::BinaryOp::Sub => crate::bridge::expr_eval::BinaryOp::Sub, @@ -29,11 +52,14 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva nodedb_sql::types::BinaryOp::Or => crate::bridge::expr_eval::BinaryOp::Or, nodedb_sql::types::BinaryOp::Concat => crate::bridge::expr_eval::BinaryOp::Concat, }, - right: Box::new(sql_expr_to_bridge_expr(right)), + right: Box::new(convert_expr_inner(right, qualify)), }, SqlExpr::Function { name, args, .. 
} => BExpr::Function { name: name.clone(), - args: args.iter().map(sql_expr_to_bridge_expr).collect(), + args: args + .iter() + .map(|a| convert_expr_inner(a, qualify)) + .collect(), }, SqlExpr::Case { operand, @@ -42,14 +68,19 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva } => BExpr::Case { operand: operand .as_ref() - .map(|e| Box::new(sql_expr_to_bridge_expr(e))), + .map(|e| Box::new(convert_expr_inner(e, qualify))), when_thens: when_then .iter() - .map(|(w, t)| (sql_expr_to_bridge_expr(w), sql_expr_to_bridge_expr(t))) + .map(|(w, t)| { + ( + convert_expr_inner(w, qualify), + convert_expr_inner(t, qualify), + ) + }) .collect(), else_expr: else_expr .as_ref() - .map(|e| Box::new(sql_expr_to_bridge_expr(e))), + .map(|e| Box::new(convert_expr_inner(e, qualify))), }, SqlExpr::Cast { expr, to_type } => { let cast_type = match to_type.to_uppercase().as_str() { @@ -63,18 +94,18 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva _ => crate::bridge::expr_eval::CastType::String, }; BExpr::Cast { - expr: Box::new(sql_expr_to_bridge_expr(expr)), + expr: Box::new(convert_expr_inner(expr, qualify)), to_type: cast_type, } } SqlExpr::Wildcard => BExpr::Column("*".into()), // NOT e / -e → evaluator's Negate (handles both bool and numeric). - SqlExpr::UnaryOp { expr, .. } => BExpr::Negate(Box::new(sql_expr_to_bridge_expr(expr))), + SqlExpr::UnaryOp { expr, .. } => BExpr::Negate(Box::new(convert_expr_inner(expr, qualify))), // `e IS NULL` / `e IS NOT NULL` — direct passthrough. SqlExpr::IsNull { expr, negated } => BExpr::IsNull { - expr: Box::new(sql_expr_to_bridge_expr(expr)), + expr: Box::new(convert_expr_inner(expr, qualify)), negated: *negated, }, @@ -87,9 +118,9 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva high, negated, } => { - let e = sql_expr_to_bridge_expr(expr); - let l = sql_expr_to_bridge_expr(low); - let h = sql_expr_to_bridge_expr(high); + let e = convert_expr_inner(expr, qualify); + let l = convert_expr_inner(low, qualify); + let h = convert_expr_inner(high, qualify); if *negated { let lt = BExpr::BinaryOp { left: Box::new(e.clone()), @@ -134,7 +165,7 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva list, negated, } => { - let target = sql_expr_to_bridge_expr(expr); + let target = convert_expr_inner(expr, qualify); if list.is_empty() { // Empty list: `e IN ()` = false, `e NOT IN ()` = true. 
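+                // (More permissive than PostgreSQL, which rejects an
+                // empty IN list at parse time.)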
return BExpr::Literal(nodedb_types::Value::Bool(*negated)); @@ -157,7 +188,7 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva .map(|item| BExpr::BinaryOp { left: Box::new(target.clone()), op: eq_op, - right: Box::new(sql_expr_to_bridge_expr(item)), + right: Box::new(convert_expr_inner(item, qualify)), }) .reduce(|acc, next| BExpr::BinaryOp { left: Box::new(acc), @@ -178,8 +209,8 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eva let call = BExpr::Function { name: "like".into(), args: vec![ - sql_expr_to_bridge_expr(expr), - sql_expr_to_bridge_expr(pattern), + convert_expr_inner(expr, qualify), + convert_expr_inner(pattern, qualify), ], }; if *negated { diff --git a/nodedb/src/control/planner/sql_plan_convert/filter.rs b/nodedb/src/control/planner/sql_plan_convert/filter.rs index 5ccb90a3..ed0162f8 100644 --- a/nodedb/src/control/planner/sql_plan_convert/filter.rs +++ b/nodedb/src/control/planner/sql_plan_convert/filter.rs @@ -122,7 +122,7 @@ fn filter_to_scan_filters(expr: &FilterExpr) -> Vec nodedb_query::scan_filter::ScanFilter { +pub(super) fn expr_filter(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter { nodedb_query::scan_filter::ScanFilter { field: String::new(), op: nodedb_query::scan_filter::FilterOp::Expr, @@ -132,6 +132,18 @@ fn expr_filter(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter { } } +/// Like [`expr_filter`] but qualifies column references with table names +/// for evaluation against join-merged documents. +pub(super) fn expr_filter_qualified(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter { + nodedb_query::scan_filter::ScanFilter { + field: String::new(), + op: nodedb_query::scan_filter::FilterOp::Expr, + value: nodedb_types::Value::Null, + clauses: Vec::new(), + expr: Some(super::expr::sql_expr_to_bridge_expr_qualified(expr)), + } +} + /// Convert a raw `SqlExpr` (from WHERE clause) to a `ScanFilter` list. /// /// Tries to produce simple, field-indexed filters for common cases (direct From aea009081400d3fc1e594b51db53318ed8a0d368 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 21:53:06 +0800 Subject: [PATCH 4/9] fix(procedural): harden executor against overflow, runaway loops, and cache collisions Replace wrapping integer arithmetic with checked_add/sub/mul/div/rem in the expression evaluator so overflows return None rather than panicking or silently wrapping. Replace the unlimited trigger budget (1-hour wall clock, MAX iterations) with a trigger_default (100k iterations, 10s) to prevent runaway procedural bodies from pinning Control Plane workers indefinitely. Guard the plan cache against 64-bit hash collisions by verifying the cached body SQL matches before returning a cached block, and evicting on mismatch. 
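For example, a runaway trigger body (illustrative syntax) such as

    WHILE true LOOP
        UPDATE counters SET n = n + 1 WHERE id = 1;
    END LOOP;

previously ran for up to an hour of wall-clock time; under
trigger_default it now aborts after 100,000 iterations or 10 seconds,
whichever is exhausted first.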
---
 .../planner/procedural/executor/core/mod.rs   |  2 +-
 .../planner/procedural/executor/eval.rs       | 19 +++++++++++++------
 .../planner/procedural/executor/fuel.rs       | 16 ++++++----------
 .../planner/procedural/executor/plan_cache.rs | 13 +++++++++++--
 4 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/nodedb/src/control/planner/procedural/executor/core/mod.rs b/nodedb/src/control/planner/procedural/executor/core/mod.rs
index 86ec06fa..dbfa9caa 100644
--- a/nodedb/src/control/planner/procedural/executor/core/mod.rs
+++ b/nodedb/src/control/planner/procedural/executor/core/mod.rs
@@ -100,7 +100,7 @@ impl<'a> StatementExecutor<'a> {
         block: &ProceduralBlock,
         bindings: &RowBindings,
     ) -> crate::Result<()> {
-        let mut budget = ExecutionBudget::unlimited();
+        let mut budget = ExecutionBudget::trigger_default();
         self.execute_block_with_exceptions(
             &block.statements,
             &block.exception_handlers,
diff --git a/nodedb/src/control/planner/procedural/executor/eval.rs b/nodedb/src/control/planner/procedural/executor/eval.rs
index ea0bf2c0..343be507 100644
--- a/nodedb/src/control/planner/procedural/executor/eval.rs
+++ b/nodedb/src/control/planner/procedural/executor/eval.rs
@@ -174,11 +174,11 @@ fn eval_binary_op(
 
     match (l, r) {
         (Value::Integer(a), Value::Integer(b)) => match op {
-            Plus => Some(Value::Integer(a + b)),
-            Minus => Some(Value::Integer(a - b)),
-            Multiply => Some(Value::Integer(a * b)),
-            Divide if *b != 0 => Some(Value::Integer(a / b)),
-            Modulo if *b != 0 => Some(Value::Integer(a % b)),
+            Plus => Some(Value::Integer(a.checked_add(*b)?)),
+            Minus => Some(Value::Integer(a.checked_sub(*b)?)),
+            Multiply => Some(Value::Integer(a.checked_mul(*b)?)),
+            Divide if *b != 0 => Some(Value::Integer(a.checked_div(*b)?)),
+            Modulo if *b != 0 => Some(Value::Integer(a.checked_rem(*b)?)),
             Gt => Some(Value::Bool(a > b)),
             GtEq => Some(Value::Bool(a >= b)),
             Lt => Some(Value::Bool(a < b)),
@@ -191,7 +191,16 @@ fn eval_binary_op(
             Plus => Some(Value::Float(a + b)),
             Minus => Some(Value::Float(a - b)),
             Multiply => Some(Value::Float(a * b)),
-            Divide if *b != 0.0 => Some(Value::Float(a / b)),
+            // `*b != 0.0` already excludes -0.0 (IEEE 754 treats them as
+            // equal), so only finiteness needs an extra check.
+            Divide if *b != 0.0 && b.is_finite() => {
+                let result = a / b;
+                if result.is_finite() {
+                    Some(Value::Float(result))
+                } else {
+                    None
+                }
+            }
             Gt => Some(Value::Bool(a > b)),
             GtEq => Some(Value::Bool(a >= b)),
             Lt => Some(Value::Bool(a < b)),
diff --git a/nodedb/src/control/planner/procedural/executor/fuel.rs b/nodedb/src/control/planner/procedural/executor/fuel.rs
index 88ff3f6b..c2f770e8 100644
--- a/nodedb/src/control/planner/procedural/executor/fuel.rs
+++ b/nodedb/src/control/planner/procedural/executor/fuel.rs
@@ -32,14 +32,10 @@ impl ExecutionBudget {
         }
     }
 
-    /// Create an unlimited budget (for triggers with no explicit limits).
-    pub fn unlimited() -> Self {
-        Self {
-            fuel_remaining: u64::MAX,
-            deadline: Instant::now() + std::time::Duration::from_secs(3600),
-            max_iterations: u64::MAX,
-            timeout_secs: 3600,
-        }
+    /// Default budget for trigger bodies. Caps iterations and wall-clock
+    /// time to prevent runaway loops from pinning Control Plane workers.
+    pub fn trigger_default() -> Self {
+        Self::new(100_000, 10)
     }
 
     /// Consume one iteration of fuel. Returns error if exhausted.
@@ -103,8 +99,8 @@ mod tests { } #[test] - fn unlimited() { - let mut budget = ExecutionBudget::unlimited(); + fn trigger_default_allows_bounded_loops() { + let mut budget = ExecutionBudget::trigger_default(); for _ in 0..10_000 { budget.consume_iteration().unwrap(); } diff --git a/nodedb/src/control/planner/procedural/executor/plan_cache.rs b/nodedb/src/control/planner/procedural/executor/plan_cache.rs index 680d4855..bc90c059 100644 --- a/nodedb/src/control/planner/procedural/executor/plan_cache.rs +++ b/nodedb/src/control/planner/procedural/executor/plan_cache.rs @@ -27,6 +27,7 @@ struct CacheInner { } struct CacheEntry { + body_sql: String, block: Arc, } @@ -58,9 +59,16 @@ impl ProcedureBlockCache { // is uncontended in the cache-hit path (early return). let mut inner = self.inner.lock().unwrap_or_else(|p| p.into_inner()); - // Cache hit — return immediately. + // Cache hit — verify the body matches (not just the hash) to guard + // against 64-bit hash collisions returning the wrong block. if let Some(entry) = inner.entries.get(&key) { - return Ok(Arc::clone(&entry.block)); + if entry.body_sql == body_sql { + return Ok(Arc::clone(&entry.block)); + } + // Hash collision with different body — evict the stale entry + // and fall through to re-parse. + inner.entries.remove(&key); + inner.order.retain(|k| *k != key); } // Cache miss — parse, cache, return. @@ -76,6 +84,7 @@ impl ProcedureBlockCache { inner.entries.insert( key, CacheEntry { + body_sql: body_sql.to_string(), block: Arc::clone(&arc), }, ); From 3e11ca29c9f1323d889996dcf5c70ff9b287c203 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 21:53:17 +0800 Subject: [PATCH 5/9] fix(ddl): replace byte-indexed NEW. rewriting with char-safe iteration The check constraint, constraint validator, RLS, trigger, and materialized view DDL parsers rewrote "NEW.col" references using byte indexing after a to_uppercase() call. This produced incorrect results for non-ASCII identifiers and is unsound for multi-byte UTF-8 sequences. Switch all affected parsers to collect chars and iterate over the char slice, matching the "NEW." prefix with eq_ignore_ascii_case on a 4-char window. Also fix the SQL preprocessor's string literal scanner to advance by 2 on '' escapes instead of continuing without incrementing, which previously caused an infinite loop on inputs with escaped single quotes. --- nodedb-sql/src/parser/preprocess.rs | 7 +- .../pgwire/ddl/collection/check_constraint.rs | 68 ++++++++++--------- .../server/pgwire/ddl/constraint/validate.rs | 22 +++--- .../pgwire/ddl/materialized_view/parse.rs | 44 +++++++++--- .../control/server/pgwire/ddl/rls/parse.rs | 52 ++++++++++++-- .../server/pgwire/ddl/trigger/parse.rs | 47 +++++++++---- .../src/control/server/pgwire/handler/plan.rs | 1 + 7 files changed, 168 insertions(+), 73 deletions(-) diff --git a/nodedb-sql/src/parser/preprocess.rs b/nodedb-sql/src/parser/preprocess.rs index d8f2d7ae..2309e1fd 100644 --- a/nodedb-sql/src/parser/preprocess.rs +++ b/nodedb-sql/src/parser/preprocess.rs @@ -270,12 +270,14 @@ fn value_to_json(value: &nodedb_types::Value) -> String { fn find_matching_brace(chars: &[char], start: usize) -> Option { let mut depth = 0; let mut in_string = false; - for i in start..chars.len() { + let mut i = start; + while i < chars.len() { match chars[i] { '\'' if !in_string => in_string = true, '\'' if in_string => { if i + 1 < chars.len() && chars[i + 1] == '\'' { - // Skip escaped quote. + // Skip '' escape — advance past both quotes. 
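+                        // e.g. in `'it''s'` the quotes at i and i+1 encode
+                        // one literal apostrophe; without advancing by 2
+                        // the scanner re-read the same index forever.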
+ i += 2; continue; } in_string = false; @@ -289,6 +291,7 @@ fn find_matching_brace(chars: &[char], start: usize) -> Option { } _ => {} } + i += 1; } None } diff --git a/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs b/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs index da6d37e5..6c037a87 100644 --- a/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs +++ b/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs @@ -271,24 +271,24 @@ fn restructure_subquery_check(expr: &str) -> RestructuredCheck { /// Converts `NEW.amount > 0` → `amount > 0` so the expression can be parsed /// as bare column references by `parse_generated_expr`. fn strip_new_prefix(sql: &str) -> String { + let chars: Vec = sql.chars().collect(); let mut result = String::with_capacity(sql.len()); - let upper = sql.to_uppercase(); - let bytes = sql.as_bytes(); let mut i = 0; - while i < bytes.len() { - if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") { - // Check word boundary before. - if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') { - result.push(bytes[i] as char); - i += 1; + while i < chars.len() { + if i + 4 <= chars.len() { + let window: String = chars[i..i + 4].iter().collect(); + if window.eq_ignore_ascii_case("NEW.") { + if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') { + result.push(chars[i]); + i += 1; + continue; + } + i += 4; continue; } - // Skip "NEW." (4 chars). - i += 4; - continue; } - result.push(bytes[i] as char); + result.push(chars[i]); i += 1; } result @@ -330,33 +330,35 @@ fn substitute_new_refs(sql: &str, fields: &HashMap) /// Replace any remaining `NEW.xxx` references (not matched by known fields) with NULL. fn replace_remaining_new_refs(text: &str) -> String { - let upper = text.to_uppercase(); + let chars: Vec = text.chars().collect(); let mut result = String::with_capacity(text.len()); let mut i = 0; - let bytes = text.as_bytes(); - while i < bytes.len() { + while i < chars.len() { // Check for "NEW." prefix (case insensitive). - if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") { - // Check word boundary before. - if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') { - result.push(bytes[i] as char); - i += 1; - continue; - } - // Find the end of the identifier after "NEW.". - let start = i + 4; - let mut end = start; - while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') { - end += 1; - } - if end > start { - result.push_str("NULL"); - i = end; - continue; + if i + 4 <= chars.len() { + let window: String = chars[i..i + 4].iter().collect(); + if window.eq_ignore_ascii_case("NEW.") { + if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') { + result.push(chars[i]); + i += 1; + continue; + } + // Find the end of the identifier after "NEW.". 
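+                    // Identifier characters are ASCII alnum/underscore
+                    // only, so a multibyte char simply terminates the
+                    // identifier instead of being sliced mid-sequence.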
+ let start = i + 4; + let mut end = start; + while end < chars.len() && (chars[end].is_ascii_alphanumeric() || chars[end] == '_') + { + end += 1; + } + if end > start { + result.push_str("NULL"); + i = end; + continue; + } } } - result.push(bytes[i] as char); + result.push(chars[i]); i += 1; } result diff --git a/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs b/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs index 0d20ccb6..3f99fb86 100644 --- a/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs +++ b/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs @@ -29,21 +29,23 @@ pub(super) fn validate_subquery_pattern(check_sql: &str) -> PgWireResult<()> { /// Strip `NEW.` prefix for validation parsing. pub(super) fn strip_new_prefix_for_validation(sql: &str) -> String { - let upper = sql.to_uppercase(); - let bytes = sql.as_bytes(); + let chars: Vec = sql.chars().collect(); let mut result = String::with_capacity(sql.len()); let mut i = 0; - while i < bytes.len() { - if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") { - if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') { - result.push(bytes[i] as char); - i += 1; + while i < chars.len() { + if i + 4 <= chars.len() { + let window: String = chars[i..i + 4].iter().collect(); + if window.eq_ignore_ascii_case("NEW.") { + if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') { + result.push(chars[i]); + i += 1; + continue; + } + i += 4; continue; } - i += 4; - continue; } - result.push(bytes[i] as char); + result.push(chars[i]); i += 1; } result diff --git a/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs b/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs index 395f9652..9b832c14 100644 --- a/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs +++ b/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs @@ -50,16 +50,11 @@ pub fn parse_create_mv(sql: &str) -> PgWireResult<(String, String, String, Strin .ok_or_else(|| sqlstate_error("42601", "expected AS SELECT ... clause"))?; let query_start = after_on_start + as_pos + KW_AS.len(); - // Find end of query: WITH clause or end of string. - let remaining = &upper[query_start..]; - let with_pos = remaining.find(" WITH").or_else(|| { - if remaining.trim_start().starts_with("WITH") { - Some(0) - } else { - None - } - }); - let query_end = with_pos.map(|p| query_start + p).unwrap_or(sql.len()); + // Find end of query: a trailing `WITH (...)` options clause, or end of string. + // We must not match a CTE `WITH cte AS (...)` inside the SELECT body. + // The options clause is always `WITH (` — a CTE `WITH` is followed by + // an identifier, never `(`. + let query_end = find_trailing_with_options(&upper, query_start).unwrap_or(sql.len()); let query_sql = sql[query_start..query_end].trim().to_string(); if query_sql.is_empty() { @@ -71,6 +66,35 @@ pub fn parse_create_mv(sql: &str) -> PgWireResult<(String, String, String, Strin Ok((name, source, query_sql, refresh_mode)) } +/// Find a trailing `WITH (...)` options clause that is NOT a CTE. +/// +/// Scans backward from the end of the string for `WITH` followed (after +/// optional whitespace) by `(`. CTE syntax is `WITH AS (...)` — the +/// word after `WITH` is an identifier, not `(`. +fn find_trailing_with_options(upper: &str, query_start: usize) -> Option { + let region = &upper[query_start..]; + // Search backward for the last " WITH" or "WITH" that is followed by "(". 
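+    // e.g. `... AS WITH totals AS (SELECT ...) SELECT ... WITH (refresh = 'manual')`
+    // — the CTE's WITH is followed by an identifier and is skipped; only
+    // the trailing `WITH (` is treated as the options clause.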
+ let mut search_end = region.len(); + loop { + let pos = region[..search_end].rfind("WITH")?; + // Verify word boundary before WITH. + if pos > 0 { + let before = region.as_bytes()[pos - 1]; + if before.is_ascii_alphanumeric() || before == b'_' { + search_end = pos; + continue; + } + } + // Check that WITH is followed by `(` (possibly with whitespace). + let after_with = region[pos + 4..].trim_start(); + if after_with.starts_with('(') { + return Some(query_start + pos); + } + // This WITH is a CTE or identifier — keep searching backward. + search_end = pos; + } +} + /// Extract refresh mode from WITH clause. fn extract_refresh_mode(upper: &str, sql: &str) -> String { let with_pos = match upper.rfind("WITH") { diff --git a/nodedb/src/control/server/pgwire/ddl/rls/parse.rs b/nodedb/src/control/server/pgwire/ddl/rls/parse.rs index ac7bcbfc..00aad0e2 100644 --- a/nodedb/src/control/server/pgwire/ddl/rls/parse.rs +++ b/nodedb/src/control/server/pgwire/ddl/rls/parse.rs @@ -70,10 +70,7 @@ pub fn parse_create_rls_policy( .map(|i| using_idx + 1 + i) .unwrap_or(parts.len()); - let predicate_str = parts[using_idx + 1..pred_end] - .join(" ") - .trim_matches(|c: char| c == '(' || c == ')') - .to_string(); + let predicate_str = strip_outer_parens(&parts[using_idx + 1..pred_end].join(" ")); let is_restrictive = parts[pred_end..] .iter() @@ -114,7 +111,7 @@ pub fn parse_create_rls_policy( let field = pred_parts[0]; let op = pred_parts[1]; - let value_str = pred_parts[2..].join(" ").trim_matches('\'').to_string(); + let value_str = strip_single_quotes(&pred_parts[2..].join(" ")); let filter = crate::bridge::scan_filter::ScanFilter { field: field.to_string(), @@ -161,3 +158,48 @@ pub fn parse_create_rls_policy( on_deny, }) } + +/// Strip at most one matched pair of outer single quotes. +/// +/// `"'hello'"` → `"hello"`, `"no_quotes"` → `"no_quotes"`. +fn strip_single_quotes(s: &str) -> String { + let trimmed = s.trim(); + if trimmed.len() >= 2 && trimmed.starts_with('\'') && trimmed.ends_with('\'') { + trimmed[1..trimmed.len() - 1].to_string() + } else { + trimmed.to_string() + } +} + +/// Strip at most one matched pair of outer parentheses. Unlike +/// `trim_matches('(' | ')')`, this preserves balanced inner parens. +/// +/// `"((x > 0) AND (y = 1))"` → `"(x > 0) AND (y = 1)"` +/// `"(x > 0)"` → `"x > 0"` +/// `"x > 0"` → `"x > 0"` (no outer parens) +fn strip_outer_parens(s: &str) -> String { + let trimmed = s.trim(); + if trimmed.starts_with('(') && trimmed.ends_with(')') { + // Verify the outer parens are actually matched (not two separate groups). 
+        let inner = &trimmed[1..trimmed.len() - 1];
+        let mut depth = 0i32;
+        let mut balanced = true;
+        for ch in inner.chars() {
+            match ch {
+                '(' => depth += 1,
+                ')' => {
+                    depth -= 1;
+                    if depth < 0 {
+                        balanced = false;
+                        break;
+                    }
+                }
+                _ => {}
+            }
+        }
+        if balanced && depth == 0 {
+            return inner.trim().to_string();
+        }
+    }
+    trimmed.to_string()
+}
diff --git a/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs b/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
index d70e829a..5797cc16 100644
--- a/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
+++ b/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
@@ -320,21 +320,43 @@ fn extract_dollar_quoted_body(s: &str) -> Option<(&str, String)> {
 
 fn find_begin_pos(s: &str) -> Option<usize> {
-    let upper = s.to_uppercase();
-    let mut search_from = 0;
-    loop {
-        let pos = upper[search_from..].find("BEGIN")?;
-        let abs_pos = search_from + pos;
-        let before_ok = abs_pos == 0
-            || !s.as_bytes()[abs_pos - 1].is_ascii_alphanumeric()
-                && s.as_bytes()[abs_pos - 1] != b'_';
-        let after_pos = abs_pos + 5;
-        let after_ok = after_pos >= s.len()
-            || !s.as_bytes()[after_pos].is_ascii_alphanumeric() && s.as_bytes()[after_pos] != b'_';
-        if before_ok && after_ok {
-            return Some(abs_pos);
+    let bytes = s.as_bytes();
+    let mut i = 0;
+
+    while i < bytes.len() {
+        // Skip single-quoted string literals so that 'BEGIN' inside a
+        // string (e.g. in a WHEN clause) is not mistaken for the body start.
+        if bytes[i] == b'\'' {
+            i += 1;
+            while i < bytes.len() {
+                if bytes[i] == b'\'' {
+                    i += 1;
+                    // Handle '' escape.
+                    if i < bytes.len() && bytes[i] == b'\'' {
+                        i += 1;
+                        continue;
+                    }
+                    break;
+                }
+                i += 1;
+            }
+            continue;
+        }
+
+        // Check for "BEGIN" keyword at this position. Compare the raw bytes
+        // ASCII-case-insensitively: indexing `s.to_uppercase()` with byte
+        // offsets taken from `s` can drift or panic on multi-byte UTF-8.
+        if i + 5 <= bytes.len() && bytes[i..i + 5].eq_ignore_ascii_case(b"BEGIN") {
+            let before_ok = i == 0 || !bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_';
+            let after_pos = i + 5;
+            let after_ok = after_pos >= bytes.len()
+                || !bytes[after_pos].is_ascii_alphanumeric() && bytes[after_pos] != b'_';
+            if before_ok && after_ok {
+                return Some(i);
+            }
         }
-        search_from = abs_pos + 5;
+        i += 1;
     }
+    None
 }
 
 #[cfg(test)]
diff --git a/nodedb/src/control/server/pgwire/handler/plan.rs b/nodedb/src/control/server/pgwire/handler/plan.rs
index 955e9567..8d4c8e47 100644
--- a/nodedb/src/control/server/pgwire/handler/plan.rs
+++ b/nodedb/src/control/server/pgwire/handler/plan.rs
@@ -124,6 +124,7 @@ pub(super) fn describe_plan(plan: &PhysicalPlan) -> PlanKind {
         | PhysicalPlan::Query(QueryOp::FacetCounts { .. })
         | PhysicalPlan::Query(QueryOp::HashJoin { .. })
         | PhysicalPlan::Query(QueryOp::InlineHashJoin { .. })
+        | PhysicalPlan::Query(QueryOp::RecursiveScan { .. })
         | PhysicalPlan::Graph(GraphOp::Algo { .. })
         | PhysicalPlan::Graph(GraphOp::Match { .. }) => PlanKind::MultiRow,
 

From ed2058c77aff7fcce4c0c663a6cb41590fd2e339 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Thu, 16 Apr 2026 21:53:25 +0800
Subject: [PATCH 6/9] fix(defaults): evaluate arbitrary DEFAULT expressions via const-fold

The DEFAULT resolver previously only handled a small set of hard-coded
keywords (NOW, UUID, etc.) and returned None for everything else. Add a
fallback path that parses the DEFAULT string through sqlparser-rs and
runs it through the planner's constant-folding evaluator, enabling
DEFAULT values like upper('hello'), 1 + 2, or concat('a', 'b') to
resolve at insert time without a Data Plane round-trip.
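Roughly, the new path looks like this (a sketch only — the exact call
sites and signatures are in the diff below):

    // DEFAULT upper('hello'), previously unrecognized and dropped:
    let expr = nodedb_sql::parse_expr_string("upper('hello')").unwrap();
    let value = nodedb_sql::planner::const_fold::fold_constant_default(&expr);
    // value is Some(SqlValue::String("HELLO")), resolved entirely at plan time.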
Also replace wrapping integer arithmetic in the const-folder with
checked variants to prevent silent overflow on constant expressions.
---
 nodedb-sql/src/planner/const_fold.rs                 |  6 ++---
 .../sql_plan_convert/value/defaults.rs               | 27 ++++++++++++++++++-
 2 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/nodedb-sql/src/planner/const_fold.rs b/nodedb-sql/src/planner/const_fold.rs
index 26a1fadb..d13f6028 100644
--- a/nodedb-sql/src/planner/const_fold.rs
+++ b/nodedb-sql/src/planner/const_fold.rs
@@ -64,9 +64,9 @@ pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
 ) -> Option<SqlValue> {
     Some(match (l, op, r) {
-        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a + b),
-        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a - b),
-        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a * b),
+        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
+        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
+        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
         (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
         (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
         (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
diff --git a/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs b/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
index 09958e90..57d3939f 100644
--- a/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
@@ -67,5 +67,30 @@ fn parse_parametric_or_literal(expr: &str, upper: &str) -> Option<Value> {
+fn try_const_fold_default(expr: &str) -> Option<Value> {
+    let sql_expr = nodedb_sql::parse_expr_string(expr).ok()?;
+    let folded = nodedb_sql::planner::const_fold::fold_constant_default(&sql_expr)?;
+    Some(sql_value_to_ndb(folded))
+}
+
+fn sql_value_to_ndb(v: nodedb_sql::types::SqlValue) -> nodedb_types::Value {
+    use nodedb_sql::types::SqlValue;
+    match v {
+        SqlValue::Null => nodedb_types::Value::Null,
+        SqlValue::Bool(b) => nodedb_types::Value::Bool(b),
+        SqlValue::Int(i) => nodedb_types::Value::Integer(i),
+        SqlValue::Float(f) => nodedb_types::Value::Float(f),
+        SqlValue::String(s) => nodedb_types::Value::String(s),
+        SqlValue::Bytes(b) => nodedb_types::Value::Bytes(b),
+        SqlValue::Array(a) => {
+            nodedb_types::Value::Array(a.into_iter().map(sql_value_to_ndb).collect())
+        }
+    }
 }

From 2a74984378a67ae4d97fb118d9c58b0754ed0f75 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Thu, 16 Apr 2026 21:53:35 +0800
Subject: [PATCH 7/9] test: add integration tests for parser and planner correctness

Cover the bug-fixes landed in this series:

- sql_arithmetic_overflow: checked arithmetic in const-fold and eval
- sql_default_expressions: DEFAULT parsing and const-fold fallback
- sql_join_correctness: non-equi join predicates, NATURAL JOIN error
- sql_parser_string_handling: escaped single-quote infinite-loop fix
- sql_procedure_cache_safety: plan-cache hash-collision eviction
- sql_recursive_cte: join-link tree-traversal correctness
- sql_rls_predicate_parse: UTF-8 safe NEW.
rewriting in RLS/DDL parsers - sql_trigger_fuel: trigger budget cap prevents unbounded execution - sql_utf8_expressions: multi-byte identifiers in constraint expressions --- nodedb/tests/procedure_execution.rs | 4 +- nodedb/tests/sql_arithmetic_overflow.rs | 198 ++++++++++++ nodedb/tests/sql_default_expressions.rs | 170 +++++++++++ nodedb/tests/sql_join_correctness.rs | 338 +++++++++++++++++++++ nodedb/tests/sql_parser_string_handling.rs | 279 +++++++++++++++++ nodedb/tests/sql_procedure_cache_safety.rs | 106 +++++++ nodedb/tests/sql_recursive_cte.rs | 171 +++++++++++ nodedb/tests/sql_rls_predicate_parse.rs | 120 ++++++++ nodedb/tests/sql_trigger_fuel.rs | 98 ++++++ nodedb/tests/sql_utf8_expressions.rs | 145 +++++++++ 10 files changed, 1627 insertions(+), 2 deletions(-) create mode 100644 nodedb/tests/sql_arithmetic_overflow.rs create mode 100644 nodedb/tests/sql_default_expressions.rs create mode 100644 nodedb/tests/sql_join_correctness.rs create mode 100644 nodedb/tests/sql_parser_string_handling.rs create mode 100644 nodedb/tests/sql_procedure_cache_safety.rs create mode 100644 nodedb/tests/sql_recursive_cte.rs create mode 100644 nodedb/tests/sql_rls_predicate_parse.rs create mode 100644 nodedb/tests/sql_trigger_fuel.rs create mode 100644 nodedb/tests/sql_utf8_expressions.rs diff --git a/nodedb/tests/procedure_execution.rs b/nodedb/tests/procedure_execution.rs index 5c943d8d..faae57e3 100644 --- a/nodedb/tests/procedure_execution.rs +++ b/nodedb/tests/procedure_execution.rs @@ -390,10 +390,10 @@ fn execution_budget_exhaustion() { } #[test] -fn execution_budget_unlimited() { +fn execution_budget_trigger_default() { use nodedb::control::planner::procedural::executor::fuel::ExecutionBudget; - let mut budget = ExecutionBudget::unlimited(); + let mut budget = ExecutionBudget::trigger_default(); for _ in 0..1000 { assert!(budget.consume_iteration().is_ok()); } diff --git a/nodedb/tests/sql_arithmetic_overflow.rs b/nodedb/tests/sql_arithmetic_overflow.rs new file mode 100644 index 00000000..d28685bf --- /dev/null +++ b/nodedb/tests/sql_arithmetic_overflow.rs @@ -0,0 +1,198 @@ +//! Integration coverage for integer overflow handling in SQL expressions. +//! +//! The const-folder and procedural executor must detect integer overflow +//! and return an error rather than panicking (debug) or silently wrapping +//! (release). Float divide-by-zero must return an error, not ±Inf. + +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// Const-folder overflow (nodedb-sql planner) +// --------------------------------------------------------------------------- + +/// `i64::MAX + 1` in a constant expression must not panic (debug) or wrap +/// to a negative number (release). An error or null are both acceptable. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_addition_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT 9223372036854775807 + 1").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + assert!( + !val.contains("-9223372036854775808"), + "i64::MAX + 1 must not silently wrap to i64::MIN: got {val}" + ); + } + } + } +} + +/// `i64::MAX * 2` must not panic or wrap. 
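+/// (Without checked arithmetic a release build wraps silently:
+/// (2^63 - 1) * 2 = 2^64 - 2, which is -2 in two's complement — hence the
+/// `-2` guard in the assertion below.)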
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_multiplication_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT 9223372036854775807 * 2").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + // Wrapped value would be -2. + assert!( + !val.contains("\"-2\""), + "i64::MAX * 2 must not silently wrap: got {val}" + ); + } + } + } +} + +/// `i64::MIN - 1` must not panic or wrap. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_subtraction_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT -9223372036854775808 - 1").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + // Wrapped value would be i64::MAX = 9223372036854775807. + assert!( + !val.contains("9223372036854775807"), + "i64::MIN - 1 must not silently wrap to i64::MAX: got {val}" + ); + } + } + } +} + +// --------------------------------------------------------------------------- +// Const-folder in INSERT context +// --------------------------------------------------------------------------- + +/// Overflow in an INSERT VALUES expression should be caught before storage. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_overflow_in_insert_values() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION overflow_tbl TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + v BIGINT)", + ) + .await + .unwrap(); + + let result = server + .exec("INSERT INTO overflow_tbl (id, v) VALUES ('k', 9223372036854775807 * 2)") + .await; + + assert!( + result.is_err(), + "INSERT with overflowing constant should fail, not store wrapped value" + ); +} + +// --------------------------------------------------------------------------- +// Procedural executor overflow (triggers / DO blocks) +// --------------------------------------------------------------------------- + +/// Integer overflow in a DO block's variable arithmetic should error, not +/// panic or silently wrap. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn procedural_eval_integer_overflow_returns_error() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION proc_ov TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, v BIGINT)") + .await + .unwrap(); + + // A DO block (or procedure) that overflows during evaluation. + let result = server + .exec( + "DO $$ \ + DECLARE x BIGINT := 9223372036854775807; \ + BEGIN \ + x := x + 1; \ + INSERT INTO proc_ov (id, v) VALUES ('k', x); \ + END $$", + ) + .await; + + assert!( + result.is_err(), + "integer overflow in procedural block should error, not wrap" + ); +} + +/// `i64::MIN / -1` is undefined behavior in two's complement and must error. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn procedural_eval_min_div_neg1_returns_error() { + let server = TestServer::start().await; + + let result = server + .exec( + "DO $$ \ + DECLARE x BIGINT := -9223372036854775808; \ + DECLARE y BIGINT; \ + BEGIN \ + y := x / -1; \ + END $$", + ) + .await; + + assert!( + result.is_err(), + "i64::MIN / -1 in procedural block should error, not panic" + ); +} + +/// Float divide by negative zero should return an error or NULL, not ±Inf. 
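+/// (Under IEEE 754, 1.0 / -0.0 evaluates to f64::NEG_INFINITY; note that
+/// -0.0 == 0.0 under float comparison, so only a bit-level zero check lets
+/// a negative zero through to the division.)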
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn procedural_eval_float_div_neg_zero() {
+    let server = TestServer::start().await;
+
+    server
+        .exec("CREATE COLLECTION fdiv TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, v FLOAT)")
+        .await
+        .unwrap();
+
+    // -0.0 can slip past a divide-by-zero guard (depending on how the zero
+    // check is written) and produce -inf: 1.0 / -0.0 == f64::NEG_INFINITY.
+    let result = server
+        .exec(
+            "DO $$ \
+             DECLARE a FLOAT := 1.0; \
+             DECLARE b FLOAT := -0.0; \
+             DECLARE c FLOAT; \
+             BEGIN \
+             c := a / b; \
+             INSERT INTO fdiv (id, v) VALUES ('k', c); \
+             END $$",
+        )
+        .await;
+
+    // Either error, or if it succeeds, the stored value must not be infinity.
+    if result.is_ok() {
+        let rows = server
+            .query_text("SELECT v FROM fdiv WHERE id = 'k'")
+            .await
+            .unwrap();
+        if let Some(val) = rows.first() {
+            assert!(
+                !val.to_lowercase().contains("inf"),
+                "float / -0.0 should not produce infinity: got {val}"
+            );
+        }
+    }
+}
diff --git a/nodedb/tests/sql_default_expressions.rs b/nodedb/tests/sql_default_expressions.rs
new file mode 100644
index 00000000..4091f32f
--- /dev/null
+++ b/nodedb/tests/sql_default_expressions.rs
@@ -0,0 +1,170 @@
+//! Integration coverage for DEFAULT expression evaluation in INSERT.
+//!
+//! The planner's `evaluate_default_expr` recognizes only a fixed keyword list
+//! (UUID_V7, NOW(), NANOID, literals). Any other expression returns None,
+//! causing the column to be silently omitted. These tests verify that
+//! expression-based defaults are evaluated, not dropped.
+
+mod common;
+
+use common::pgwire_harness::TestServer;
+
+/// `DEFAULT upper('x')` — a scalar function call as a default value.
+/// The planner should evaluate this rather than dropping the column.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn default_scalar_function_upper() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION def_fn TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             a TEXT DEFAULT upper('x'))",
+        )
+        .await
+        .unwrap();
+
+    server
+        .exec("INSERT INTO def_fn (id) VALUES ('k1')")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text("SELECT a FROM def_fn WHERE id = 'k1'")
+        .await
+        .unwrap();
+    assert_eq!(rows.len(), 1, "row should exist");
+    // The default should produce 'X'. If the column was silently dropped,
+    // the value will be null/absent.
+    assert!(
+        rows[0].contains('X'),
+        "DEFAULT upper('x') should produce 'X', got {:?}",
+        rows[0]
+    );
+}
+
+/// `DEFAULT lower('HELLO')` — another scalar function.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn default_scalar_function_lower() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION def_lower TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             tag TEXT DEFAULT lower('HELLO'))",
+        )
+        .await
+        .unwrap();
+
+    server
+        .exec("INSERT INTO def_lower (id) VALUES ('k1')")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text("SELECT tag FROM def_lower WHERE id = 'k1'")
+        .await
+        .unwrap();
+    assert_eq!(rows.len(), 1);
+    assert!(
+        rows[0].contains("hello"),
+        "DEFAULT lower('HELLO') should produce 'hello', got {:?}",
+        rows[0]
+    );
+}
+
+/// `DEFAULT 1 + 2` — a binary arithmetic expression as default.
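+/// Exercises the const-fold fallback: `1 + 2` should fold to `3` at plan
+/// time, through the checked-arithmetic path added earlier in this series.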
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_arithmetic_expression() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_arith TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + v INT DEFAULT 1 + 2)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_arith (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT v FROM def_arith WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains('3'), + "DEFAULT 1 + 2 should produce 3, got {:?}", + rows[0] + ); +} + +/// `DEFAULT concat('a', 'b')` — a multi-arg function. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_concat_function() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_concat TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + label TEXT DEFAULT concat('hello', '_', 'world'))", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_concat (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT label FROM def_concat WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("hello_world"), + "DEFAULT concat should produce 'hello_world', got {:?}", + rows[0] + ); +} + +/// Verify that recognized defaults (literal string, NOW(), UUID_V7) still work. +/// This is a baseline — not a new bug, just ensures we don't regress. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_recognized_expressions_still_work() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_known TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + status TEXT DEFAULT 'active', \ + uid TEXT DEFAULT UUID_V7)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_known (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT status FROM def_known WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("active"), + "DEFAULT 'active' should work: got {:?}", + rows[0] + ); +} diff --git a/nodedb/tests/sql_join_correctness.rs b/nodedb/tests/sql_join_correctness.rs new file mode 100644 index 00000000..7dd3a86a --- /dev/null +++ b/nodedb/tests/sql_join_correctness.rs @@ -0,0 +1,338 @@ +//! Integration coverage for SQL JOIN correctness. +//! +//! Covers: RIGHT JOIN plan conversion (inline swap), multi-predicate join +//! conditions (non-equi predicates), and NATURAL JOIN handling. 
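+//!
+//! Background for the RIGHT JOIN cases: the planner rewrites `A RIGHT JOIN B`
+//! as `B LEFT JOIN A`, so any inline-build flags attached to the original
+//! left and right inputs must swap along with the inputs themselves.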
+ +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// Helper: set up three related tables for join tests +// --------------------------------------------------------------------------- + +async fn setup_join_tables(server: &TestServer) { + server + .exec( + "CREATE COLLECTION j_t1 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + name TEXT, \ + x INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION j_t2 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + t1_id TEXT, \ + y INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION j_t3 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + t2_id TEXT, \ + z INT)", + ) + .await + .unwrap(); + + // t1 data + server + .exec("INSERT INTO j_t1 (id, name, x) VALUES ('a', 'Alice', 10)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t1 (id, name, x) VALUES ('b', 'Bob', 20)") + .await + .unwrap(); + + // t2 data — references t1 + server + .exec("INSERT INTO j_t2 (id, t1_id, y) VALUES ('p', 'a', 100)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t2 (id, t1_id, y) VALUES ('q', 'a', 200)") + .await + .unwrap(); + + // t3 data — references t2, plus one unmatched + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('x', 'p', 1)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('y', 'q', 2)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('z', 'NONE', 3)") + .await + .unwrap(); +} + +// --------------------------------------------------------------------------- +// RIGHT JOIN with nested join on the left +// --------------------------------------------------------------------------- + +/// A RIGHT JOIN where the left side is itself a join. The planner rewrites +/// RIGHT → LEFT by swapping sides, but must also swap inline_left/inline_right. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn right_join_with_nested_left_join_returns_correct_rows() { + let server = TestServer::start().await; + setup_join_tables(&server).await; + + // t3 has 3 rows. RIGHT JOIN means all t3 rows appear, with NULLs for + // unmatched left-side (t1 JOIN t2) rows. + let rows = server + .query_text( + "SELECT j_t3.id FROM j_t1 \ + INNER JOIN j_t2 ON j_t1.id = j_t2.t1_id \ + RIGHT JOIN j_t3 ON j_t2.id = j_t3.t2_id", + ) + .await + .unwrap(); + + // All 3 t3 rows must appear (x, y, z). z has no match in the left join + // but RIGHT JOIN preserves it. + assert_eq!( + rows.len(), + 3, + "RIGHT JOIN should preserve all right-side rows: got {rows:?}" + ); +} + +/// Simple RIGHT JOIN (no nested join) as a baseline. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn simple_right_join_preserves_right_rows() { + let server = TestServer::start().await; + setup_join_tables(&server).await; + + let rows = server + .query_text( + "SELECT j_t3.id FROM j_t2 \ + RIGHT JOIN j_t3 ON j_t2.id = j_t3.t2_id", + ) + .await + .unwrap(); + + assert_eq!( + rows.len(), + 3, + "simple RIGHT JOIN should return all 3 t3 rows: got {rows:?}" + ); +} + +// --------------------------------------------------------------------------- +// Multiple non-equi predicates in JOIN ON +// --------------------------------------------------------------------------- + +/// All non-equi predicates in a JOIN ON clause must be preserved. +/// Only keeping the first one silently drops filter conditions. 
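+/// E.g. for `ON a.x = b.x AND a.y < b.y AND a.y > 0`, the equi predicate
+/// drives the hash join and both residual predicates must be re-checked
+/// against every candidate match.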
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn join_preserves_all_non_equi_predicates() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION jn_left TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             x INT, \
+             y INT)",
+        )
+        .await
+        .unwrap();
+    server
+        .exec(
+            "CREATE COLLECTION jn_right TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             x INT, \
+             y INT)",
+        )
+        .await
+        .unwrap();
+
+    server
+        .exec("INSERT INTO jn_left (id, x, y) VALUES ('L1', 10, 5)")
+        .await
+        .unwrap();
+    server
+        .exec("INSERT INTO jn_left (id, x, y) VALUES ('L2', 10, 50)")
+        .await
+        .unwrap();
+
+    server
+        .exec("INSERT INTO jn_right (id, x, y) VALUES ('R1', 10, 20)")
+        .await
+        .unwrap();
+
+    // Both non-equi conditions must apply:
+    // L1: x=10, y=5.  R1: x=10, y=20.
+    // Condition: jn_left.x = jn_right.x AND jn_left.y < jn_right.y AND jn_left.y > 0
+    // L1 matches: x=x, 5 < 20, 5 > 0 → yes
+    // L2 fails:   x=x, 50 < 20 is FALSE → no
+
+    let rows = server
+        .query_text(
+            "SELECT jn_left.id FROM jn_left \
+             JOIN jn_right ON jn_left.x = jn_right.x \
+             AND jn_left.y < jn_right.y \
+             AND jn_left.y > 0",
+        )
+        .await
+        .unwrap();
+
+    assert_eq!(
+        rows.len(),
+        1,
+        "only L1 should match all three join conditions: got {rows:?}"
+    );
+    assert!(
+        rows[0].contains("L1"),
+        "matched row should be L1: got {:?}",
+        rows[0]
+    );
+}
+
+/// If the second non-equi predicate is dropped, a2 would incorrectly appear.
+/// This is a regression guard: assert a2 is NOT in the result.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn join_does_not_drop_second_non_equi_predicate() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION jd_a TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             val INT, \
+             lo INT, \
+             hi INT)",
+        )
+        .await
+        .unwrap();
+    server
+        .exec(
+            "CREATE COLLECTION jd_b TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             val INT, \
+             bound INT)",
+        )
+        .await
+        .unwrap();
+
+    // A: val=1, lo=0, hi=100. B: val=1, bound=50.
+    // Condition: a.val = b.val AND a.lo < b.bound AND a.hi > b.bound
+    // → 0 < 50 AND 100 > 50 → match.
+    server
+        .exec("INSERT INTO jd_a (id, val, lo, hi) VALUES ('a1', 1, 0, 100)")
+        .await
+        .unwrap();
+    // A2: val=1, lo=0, hi=10.
+    // → 0 < 50 AND 10 > 50 → no match (second non-equi fails).
+    server
+        .exec("INSERT INTO jd_a (id, val, lo, hi) VALUES ('a2', 1, 0, 10)")
+        .await
+        .unwrap();
+
+    server
+        .exec("INSERT INTO jd_b (id, val, bound) VALUES ('b1', 1, 50)")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text(
+            "SELECT jd_a.id FROM jd_a \
+             JOIN jd_b ON jd_a.val = jd_b.val \
+             AND jd_a.lo < jd_b.bound \
+             AND jd_a.hi > jd_b.bound",
+        )
+        .await
+        .unwrap();
+
+    assert_eq!(rows.len(), 1, "only a1 should match: got {rows:?}");
+    assert!(
+        rows[0].contains("a1"),
+        "matched row should be a1: got {:?}",
+        rows[0]
+    );
+    // Regression guard: if the second non-equi (`hi > bound`) was dropped,
+    // a2 would also appear.
+ assert!( + !rows.iter().any(|r| r.contains("a2")), + "a2 should NOT match — hi=10 is not > bound=50: got {rows:?}" + ); +} + +// --------------------------------------------------------------------------- +// NATURAL JOIN +// --------------------------------------------------------------------------- + +/// NATURAL JOIN should either compute the shared column set and produce a +/// correct equi-join, or return an explicit unsupported error. It must NOT +/// silently produce a cartesian product. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn natural_join_is_not_cartesian() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION nat_a TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + shared_col TEXT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION nat_b TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + shared_col TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO nat_a (id, shared_col) VALUES ('a1', 'x')") + .await + .unwrap(); + server + .exec("INSERT INTO nat_a (id, shared_col) VALUES ('a2', 'y')") + .await + .unwrap(); + + server + .exec("INSERT INTO nat_b (id, shared_col) VALUES ('b1', 'x')") + .await + .unwrap(); + + let result = server + .query_text("SELECT nat_a.id FROM nat_a NATURAL JOIN nat_b") + .await; + + match result { + Ok(rows) => { + // If NATURAL JOIN is supported, shared_col='x' should produce + // a match on a1↔b1 only. Cartesian would give 2 rows. + assert!( + rows.len() <= 1, + "NATURAL JOIN should match on shared_col, not produce cartesian ({} rows): {rows:?}", + rows.len() + ); + } + Err(msg) => { + // An explicit "unsupported" error is acceptable. + assert!( + msg.to_lowercase().contains("natural") + || msg.to_lowercase().contains("unsupported") + || msg.to_lowercase().contains("not supported"), + "error should mention NATURAL JOIN: {msg}" + ); + } + } +} diff --git a/nodedb/tests/sql_parser_string_handling.rs b/nodedb/tests/sql_parser_string_handling.rs new file mode 100644 index 00000000..17728ee6 --- /dev/null +++ b/nodedb/tests/sql_parser_string_handling.rs @@ -0,0 +1,279 @@ +//! Integration coverage for SQL parsers handling string literals correctly. +//! +//! Hand-rolled parsers must skip over single-quoted string literals when +//! scanning for structural delimiters (`BEGIN`, `)`, `,`, `WITH`, `{}`). +//! These tests verify that embedded keywords and special characters inside +//! string values do not corrupt the parse. + +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// 1. CREATE TRIGGER: `BEGIN` inside a string literal in WHEN clause +// --------------------------------------------------------------------------- + +/// A WHEN clause containing the word 'BEGIN' inside a string literal must not +/// confuse the header/body split. The trigger should parse correctly and the +/// body should execute as written. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn trigger_begin_inside_string_literal_does_not_corrupt_parse() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION items TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, label TEXT)") + .await + .unwrap(); + + server.exec("CREATE COLLECTION audit_log").await.unwrap(); + + // The WHEN clause references the string 'BEGIN' — the parser must not + // treat this as the body start. 
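+    // (find_begin_pos skips single-quoted literals while scanning, so the
+    // quoted 'BEGIN' above is data, not the body delimiter.)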
+ let result = server + .exec( + "CREATE TRIGGER tr_begin BEFORE INSERT ON items FOR EACH ROW \ + WHEN (NEW.label = 'BEGIN' OR NEW.label = 'open') \ + BEGIN INSERT INTO audit_log { id: NEW.id, note: 'triggered' }; END", + ) + .await; + assert!( + result.is_ok(), + "CREATE TRIGGER with 'BEGIN' in WHEN string should succeed: {:?}", + result + ); + + // The trigger should not interfere with a normal insert. + server + .exec("INSERT INTO items (id, label) VALUES ('i1', 'hello')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT id FROM items WHERE id = 'i1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1, "insert should succeed through the trigger"); +} + +// --------------------------------------------------------------------------- +// 2. INSERT: `)` inside a string literal in VALUES +// --------------------------------------------------------------------------- + +/// A string value containing `)` must not break the VALUES pre-parse that +/// scans for the closing paren. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_paren_inside_string_value_parses_correctly() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION msgs TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + body TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO msgs (id, body) VALUES ('m1', 'hello)world')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT body FROM msgs WHERE id = 'm1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1, "insert with ) in string should succeed"); + assert!( + rows[0].contains("hello)world"), + "value should preserve the literal paren: got {:?}", + rows[0] + ); +} + +/// Multiple special characters inside string values in a multi-column INSERT. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_multiple_parens_in_strings() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION notes TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + a TEXT, \ + b TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO notes (id, a, b) VALUES ('n1', 'a)b', 'c(d')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT a FROM notes WHERE id = 'n1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!(rows[0].contains("a)b"), "got {:?}", rows[0]); +} + +// --------------------------------------------------------------------------- +// 3. split_values: quote tracking inside brackets/arrays +// --------------------------------------------------------------------------- + +/// An array literal containing strings with `)` must not confuse the value +/// splitter's bracket-depth tracking. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_array_with_paren_in_string_element() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION arr_test TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + tags TEXT, \ + count INT)", + ) + .await + .unwrap(); + + // ARRAY['x)y', 'z'] — the `)` inside the first element must not + // decrement the bracket counter. 
+ server + .exec("INSERT INTO arr_test (id, tags, count) VALUES ('a1', ARRAY['x)y', 'z'], 42)") + .await + .unwrap(); + + let rows = server + .query_text("SELECT count FROM arr_test WHERE id = 'a1'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "insert with array containing ')' should succeed" + ); + assert!( + rows[0].contains("42"), + "second column value should be 42, got {:?}", + rows[0] + ); +} + +// --------------------------------------------------------------------------- +// 5. Object-literal `{ }` rewriter: `''`-escaped quotes +// --------------------------------------------------------------------------- + +/// The `{ }` preprocessor must handle SQL-escaped single quotes (`''`). +/// `'it''s'` is a valid SQL string containing an apostrophe. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn object_literal_with_escaped_quote_parses_correctly() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION esc_test").await.unwrap(); + + // The `{ note: 'it''s' }` syntax goes through `find_matching_brace` + // which must correctly handle the `''` escape. + server + .exec("INSERT INTO esc_test { id: 'e1', note: 'it''s fine' }") + .await + .unwrap(); + + let rows = server + .query_text("SELECT * FROM esc_test WHERE id = 'e1'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "insert with escaped quote in object literal should succeed" + ); + assert!( + rows[0].contains("it's fine") || rows[0].contains("it''s fine"), + "value should contain the apostrophe: got {:?}", + rows[0] + ); +} + +/// Escaped quotes adjacent to the closing brace. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn object_literal_escaped_quote_near_brace() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION esc_test2").await.unwrap(); + + server + .exec("INSERT INTO esc_test2 { id: 'e2', val: 'end''s}' }") + .await + .unwrap(); + + let rows = server + .query_text("SELECT * FROM esc_test2 WHERE id = 'e2'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "escaped quote near brace should parse correctly" + ); +} + +// --------------------------------------------------------------------------- +// 7. CREATE MATERIALIZED VIEW: `WITH` inside SELECT body +// --------------------------------------------------------------------------- + +/// A materialized view whose SELECT contains a CTE (`WITH cte AS (...)`) +/// must not have the CTE keyword mistaken for the options clause. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn materialized_view_with_cte_in_select() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION mv_src TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, val INT)") + .await + .unwrap(); + server + .exec("INSERT INTO mv_src (id, val) VALUES ('r1', 10)") + .await + .unwrap(); + + // The SELECT uses a CTE — `WITH s AS (...)`. The parser must not + // truncate the query at the CTE's `WITH`. + let result = server + .exec( + "CREATE MATERIALIZED VIEW mv_cte ON mv_src AS \ + WITH s AS (SELECT id, val FROM mv_src) SELECT id, val FROM s", + ) + .await; + assert!( + result.is_ok(), + "CREATE MATERIALIZED VIEW with CTE should succeed: {:?}", + result + ); +} + +/// A materialized view whose SELECT contains a column aliased with `with_`. 
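+/// The backward options-clause scan only accepts a `WITH` followed by `(`;
+/// here the match inside `with_tax` is followed by `_tax`, so it is skipped.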
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn materialized_view_with_keyword_in_column_alias() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION mv_src2 TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, val INT)") + .await + .unwrap(); + + // `with_tax` as a column alias should not be mistaken for `WITH` options. + let result = server + .exec( + "CREATE MATERIALIZED VIEW mv_alias ON mv_src2 AS \ + SELECT id, val AS with_tax FROM mv_src2", + ) + .await; + assert!( + result.is_ok(), + "column alias starting with 'with_' should not break parse: {:?}", + result + ); +} diff --git a/nodedb/tests/sql_procedure_cache_safety.rs b/nodedb/tests/sql_procedure_cache_safety.rs new file mode 100644 index 00000000..6461e174 --- /dev/null +++ b/nodedb/tests/sql_procedure_cache_safety.rs @@ -0,0 +1,106 @@ +//! Integration coverage for procedure/trigger body cache correctness. +//! +//! The ProcedureBlockCache keys on a 64-bit hash of the body SQL without +//! storing the source for equality verification. A hash collision returns +//! the wrong compiled block. This test verifies that two different trigger +//! bodies always execute their own logic, not each other's. + +mod common; + +use std::time::Duration; + +use common::pgwire_harness::TestServer; + +/// Two triggers with different bodies on different collections must each +/// execute their own body, never the other's. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn different_trigger_bodies_execute_independently() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION cache_a TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val TEXT DEFAULT 'none')", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION cache_b TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val TEXT DEFAULT 'none')", + ) + .await + .unwrap(); + // Audit log to capture trigger side effects. + server.exec("CREATE COLLECTION trigger_log").await.unwrap(); + + // Trigger A: logs 'trigger_a_fired'. + server + .exec( + "CREATE TRIGGER trg_a AFTER INSERT ON cache_a FOR EACH ROW \ + BEGIN \ + INSERT INTO trigger_log { id: 'log_a', source: 'trigger_a_fired' }; \ + END", + ) + .await + .unwrap(); + + // Trigger B: logs 'trigger_b_fired' — different body. + server + .exec( + "CREATE TRIGGER trg_b AFTER INSERT ON cache_b FOR EACH ROW \ + BEGIN \ + INSERT INTO trigger_log { id: 'log_b', source: 'trigger_b_fired' }; \ + END", + ) + .await + .unwrap(); + + // Fire trigger A. + server + .exec("INSERT INTO cache_a (id) VALUES ('a1')") + .await + .unwrap(); + // Wait for async AFTER trigger dispatch. + tokio::time::sleep(Duration::from_millis(500)).await; + + // Fire trigger B. + server + .exec("INSERT INTO cache_b (id) VALUES ('b1')") + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify trigger A logged 'trigger_a_fired', not 'trigger_b_fired'. + let log_a = server + .query_text("SELECT * FROM trigger_log WHERE id = 'log_a'") + .await + .unwrap(); + assert_eq!(log_a.len(), 1, "trigger A should have logged: {log_a:?}"); + assert!( + log_a[0].contains("trigger_a_fired"), + "trigger A should execute its own body, not B's: {:?}", + log_a[0] + ); + // Regression guard: if cache collision occurred, A's log entry would + // contain B's message. + assert!( + !log_a[0].contains("trigger_b_fired"), + "trigger A must not execute trigger B's body (cache collision): {:?}", + log_a[0] + ); + + // Verify trigger B logged its own message. 
+    let log_b = server
+        .query_text("SELECT * FROM trigger_log WHERE id = 'log_b'")
+        .await
+        .unwrap();
+    assert_eq!(log_b.len(), 1, "trigger B should have logged: {log_b:?}");
+    assert!(
+        log_b[0].contains("trigger_b_fired"),
+        "trigger B should execute its own body: {:?}",
+        log_b[0]
+    );
+}
diff --git a/nodedb/tests/sql_recursive_cte.rs b/nodedb/tests/sql_recursive_cte.rs
new file mode 100644
index 00000000..b9ee387b
--- /dev/null
+++ b/nodedb/tests/sql_recursive_cte.rs
@@ -0,0 +1,171 @@
+//! Integration coverage for WITH RECURSIVE (recursive CTEs).
+//!
+//! The planner must preserve the recursive structure (base + recursive branch
+//! with self-reference), respect UNION vs UNION ALL, and allow configurable
+//! max iteration depth.
+
+mod common;
+
+use common::pgwire_harness::TestServer;
+
+/// Basic recursive CTE: generate numbers 1..5 using UNION ALL.
+/// The recursive branch references the CTE and increments.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn recursive_cte_generates_sequence() {
+    let server = TestServer::start().await;
+
+    // We need a base collection for the query to target.
+    server
+        .exec("CREATE COLLECTION cte_dummy TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY)")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text(
+            "WITH RECURSIVE c(n) AS (\
+             SELECT 1 \
+             UNION ALL \
+             SELECT n + 1 FROM c WHERE n < 5\
+             ) \
+             SELECT n FROM c",
+        )
+        .await;
+
+    match rows {
+        Ok(values) => {
+            assert_eq!(
+                values.len(),
+                5,
+                "recursive CTE should produce 5 rows (1..5): got {values:?}"
+            );
+            // Values should be 1, 2, 3, 4, 5 in some order.
+            let nums: Vec<i64> = values
+                .iter()
+                .filter_map(|v| v.trim().parse().ok())
+                .collect();
+            let mut sorted = nums.clone();
+            sorted.sort();
+            assert_eq!(
+                sorted,
+                vec![1, 2, 3, 4, 5],
+                "recursive CTE should produce 1..5: got {nums:?}"
+            );
+        }
+        Err(msg) => {
+            // If recursive CTEs are not yet supported, an explicit error is acceptable
+            // but a silent empty result or wrong collection is not.
+            assert!(
+                msg.to_lowercase().contains("recursive")
+                    || msg.to_lowercase().contains("not supported")
+                    || msg.to_lowercase().contains("unsupported"),
+                "error should mention recursive CTE: {msg}"
+            );
+        }
+    }
+}
+
+/// UNION ALL in recursive CTE must preserve duplicates.
+/// UNION (without ALL) should deduplicate.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn recursive_cte_union_all_preserves_duplicates() {
+    let server = TestServer::start().await;
+
+    server
+        .exec("CREATE COLLECTION cte_dup TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, grp TEXT)")
+        .await
+        .unwrap();
+
+    // Base: two identical literal rows and no recursive branch.
+    // UNION ALL should keep both base rows; UNION would deduplicate.
+    let result_all = server
+        .query_text(
+            "WITH RECURSIVE c(val) AS (\
+             SELECT 1 UNION ALL SELECT 1\
+             ) \
+             SELECT val FROM c",
+        )
+        .await;
+
+    match result_all {
+        Ok(rows) => {
+            assert_eq!(
+                rows.len(),
+                2,
+                "UNION ALL should preserve duplicate base rows: got {rows:?}"
+            );
+        }
+        Err(_) => {
+            // Acceptable if recursive CTEs aren't supported yet — but the
+            // test failing here captures the coverage gap.
+        }
+    }
+}
+
+/// Recursive CTE over actual table data: tree traversal.
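+/// The recursive branch joins `tree.parent_id` against the rows produced so
+/// far, so 'orphan' (whose parent is 'missing') must never be reachable from
+/// 'root'.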
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn recursive_cte_tree_traversal() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION tree TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + parent_id TEXT)", + ) + .await + .unwrap(); + + // Build a small tree: root → child1 → grandchild + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('root', NULL)") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('child1', 'root')") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('grandchild', 'child1')") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('orphan', 'missing')") + .await + .unwrap(); + + // Traverse descendants of 'root'. + let rows = server + .query_text( + "WITH RECURSIVE descendants(id) AS (\ + SELECT id FROM tree WHERE id = 'root' \ + UNION ALL \ + SELECT t.id FROM tree t \ + INNER JOIN descendants d ON t.parent_id = d.id\ + ) \ + SELECT id FROM descendants", + ) + .await; + + match rows { + Ok(values) => { + // Should find: root, child1, grandchild (3 rows). NOT orphan. + assert_eq!( + values.len(), + 3, + "tree traversal should find root + child1 + grandchild: got {values:?}" + ); + assert!( + !values.iter().any(|v| v.contains("orphan")), + "orphan should not appear in descendants of root: {values:?}" + ); + } + Err(msg) => { + // An explicit unsupported error is acceptable — the test captures the gap. + assert!( + msg.to_lowercase().contains("recursive") + || msg.to_lowercase().contains("not supported"), + "unexpected error: {msg}" + ); + } + } +} diff --git a/nodedb/tests/sql_rls_predicate_parse.rs b/nodedb/tests/sql_rls_predicate_parse.rs new file mode 100644 index 00000000..0361073d --- /dev/null +++ b/nodedb/tests/sql_rls_predicate_parse.rs @@ -0,0 +1,120 @@ +//! Integration coverage for RLS policy predicate parsing. +//! +//! The USING clause predicate must preserve balanced parentheses faithfully. +//! `trim_matches('(' | ')')` strips all leading/trailing parens — not just +//! one matched pair — corrupting nested boolean expressions. + +mod common; + +use common::pgwire_harness::TestServer; + +/// Double-wrapped balanced parens `((x > 0) AND (y = 1))` must be stored +/// correctly and enforced at query time. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rls_policy_nested_parens_preserved() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION rls_items TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x INT, \ + y INT)", + ) + .await + .unwrap(); + + // Create a policy with nested parens in USING predicate. + let result = server + .exec( + "CREATE RLS POLICY pos_check ON rls_items FOR READ \ + USING ((x > 0) AND (y = 1))", + ) + .await; + assert!( + result.is_ok(), + "CREATE RLS POLICY with nested parens should succeed: {:?}", + result + ); + + // Insert rows: one matching, one not. + server + .exec("INSERT INTO rls_items (id, x, y) VALUES ('ok', 5, 1)") + .await + .unwrap(); + server + .exec("INSERT INTO rls_items (id, x, y) VALUES ('bad', -1, 1)") + .await + .unwrap(); + + // Create a non-superuser to test RLS enforcement. 
+    server
+        .exec("CREATE USER rls_reader WITH PASSWORD 'pass' ROLE readonly")
+        .await
+        .unwrap();
+
+    let (reader, _handle) = server.connect_as("rls_reader", "pass").await.unwrap();
+    let result = reader.simple_query("SELECT id FROM rls_items").await;
+    match result {
+        Ok(msgs) => {
+            let mut ids: Vec<String> = Vec::new();
+            for msg in &msgs {
+                if let tokio_postgres::SimpleQueryMessage::Row(row) = msg {
+                    ids.push(row.get(0).unwrap_or("").to_string());
+                }
+            }
+            // The policy should allow `ok` (x=5 > 0, y=1) and reject `bad` (x=-1).
+            assert!(
+                ids.iter().any(|id| id.contains("ok")),
+                "row with x=5,y=1 should be visible under RLS: {ids:?}"
+            );
+            assert!(
+                !ids.iter().any(|id| id.contains("bad")),
+                "row with x=-1 should be filtered by RLS: {ids:?}"
+            );
+        }
+        Err(e) => {
+            // Extract detailed error if available.
+            let detail = if let Some(db_err) = e.as_db_error() {
+                format!(
+                    "{}: {} (SQLSTATE {})",
+                    db_err.severity(),
+                    db_err.message(),
+                    db_err.code().code()
+                )
+            } else {
+                format!("{e:?}")
+            };
+            panic!("RLS query should not error — predicate may be corrupted: {detail}");
+        }
+    }
+}
+
+/// Triple-nested parens should not be over-stripped.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn rls_policy_triple_nested_parens() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION rls_deep TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             a INT, \
+             b INT, \
+             c INT)",
+        )
+        .await
+        .unwrap();
+
+    let result = server
+        .exec(
+            "CREATE RLS POLICY deep_check ON rls_deep FOR READ \
+             USING (((a > 0) AND (b > 0)) OR (c = 99))",
+        )
+        .await;
+    assert!(
+        result.is_ok(),
+        "deeply nested RLS predicate should parse without corruption: {:?}",
+        result
+    );
+}
diff --git a/nodedb/tests/sql_trigger_fuel.rs b/nodedb/tests/sql_trigger_fuel.rs
new file mode 100644
index 00000000..0278a38a
--- /dev/null
+++ b/nodedb/tests/sql_trigger_fuel.rs
@@ -0,0 +1,98 @@
+//! Integration coverage for trigger execution fuel / budget limits.
+//!
+//! Trigger bodies run user-supplied code. An infinite loop must not pin a
+//! Control Plane worker for the full 3600-second deadline. The execution
+//! budget must cap iterations and wall-clock time to a reasonable bound.
+
+mod common;
+
+use std::time::{Duration, Instant};
+
+use common::pgwire_harness::TestServer;
+
+/// An infinite-loop trigger body must be terminated by the execution budget
+/// within a reasonable time (< 30 seconds), not the 1-hour wall clock.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn infinite_loop_trigger_terminates_within_budget() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION fuel_test TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             v INT)",
+        )
+        .await
+        .unwrap();
+
+    // Create a trigger with an infinite loop body.
+    server
+        .exec(
+            "CREATE TRIGGER fuel_loop AFTER INSERT ON fuel_test FOR EACH ROW \
+             BEGIN LOOP END LOOP; END",
+        )
+        .await
+        .unwrap();
+
+    let start = Instant::now();
+    let result = server
+        .exec("INSERT INTO fuel_test (id, v) VALUES ('a', 1)")
+        .await;
+    let elapsed = start.elapsed();
+
+    // The trigger should either:
+    // (a) error because the budget was exhausted, or
+    // (b) succeed but complete within a sane time bound.
+    // It must NOT hang for 3600 seconds.
+ assert!( + elapsed < Duration::from_secs(30), + "infinite-loop trigger took {elapsed:?} — budget should cap execution well under 1 hour" + ); + + if let Err(msg) = result { + assert!( + msg.to_lowercase().contains("fuel") + || msg.to_lowercase().contains("budget") + || msg.to_lowercase().contains("timeout") + || msg.to_lowercase().contains("iteration") + || msg.to_lowercase().contains("limit"), + "error should mention budget/fuel/timeout: {msg}" + ); + } +} + +/// A trigger with a bounded but expensive loop should succeed if within budget. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn bounded_loop_trigger_completes_normally() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION fuel_ok TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + counter INT DEFAULT 0)", + ) + .await + .unwrap(); + + // A trigger that loops 10 times — well within any reasonable budget. + server + .exec( + "CREATE TRIGGER fuel_bounded AFTER INSERT ON fuel_ok FOR EACH ROW \ + BEGIN \ + DECLARE i INT := 0; \ + WHILE i < 10 LOOP \ + i := i + 1; \ + END LOOP; \ + END", + ) + .await + .unwrap(); + + let result = server.exec("INSERT INTO fuel_ok (id) VALUES ('b1')").await; + assert!( + result.is_ok(), + "bounded loop trigger should complete normally: {:?}", + result + ); +} diff --git a/nodedb/tests/sql_utf8_expressions.rs b/nodedb/tests/sql_utf8_expressions.rs new file mode 100644 index 00000000..9c7bd08f --- /dev/null +++ b/nodedb/tests/sql_utf8_expressions.rs @@ -0,0 +1,145 @@ +//! Integration coverage for UTF-8 correctness in expression parsing. +//! +//! The expression tokenizer in `nodedb-query` iterates byte-by-byte but +//! slices `&str` — panicking on multi-byte UTF-8 codepoints. These tests +//! verify that non-ASCII characters in generated columns, check constraints, +//! and typeguard expressions do not cause panics or data corruption. + +mod common; + +use common::pgwire_harness::TestServer; + +/// A GENERATED ALWAYS AS expression containing a CJK literal must not panic. +/// The tokenizer slices `&input[i..i+2]` which crosses a char boundary on +/// 3-byte UTF-8 codepoints. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn generated_column_with_cjk_literal_does_not_panic() { + let server = TestServer::start().await; + + let result = server + .exec( + "CREATE COLLECTION utf_gen TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x TEXT, \ + y TEXT GENERATED ALWAYS AS ('你' || x))", + ) + .await; + assert!( + result.is_ok(), + "CJK literal in GENERATED expression should not panic: {:?}", + result + ); + + server + .exec("INSERT INTO utf_gen (id, x) VALUES ('u1', 'hello')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT y FROM utf_gen WHERE id = 'u1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("你hello"), + "generated column should concatenate CJK prefix: got {:?}", + rows[0] + ); +} + +/// Emoji (4-byte UTF-8) in an expression must not panic. 
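+/// With a 4-byte codepoint at byte `i`, a lookahead slice like
+/// `&input[i..i + 2]` ends inside the character, and slicing a `&str` off a
+/// char boundary panics.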
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn generated_column_with_emoji_does_not_panic() {
+    let server = TestServer::start().await;
+
+    let result = server
+        .exec(
+            "CREATE COLLECTION utf_emoji TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             tag TEXT, \
+             display TEXT GENERATED ALWAYS AS ('🎉' || tag))",
+        )
+        .await;
+    assert!(
+        result.is_ok(),
+        "emoji in GENERATED expression should not panic: {:?}",
+        result
+    );
+}
+
+/// Multi-byte characters followed by operators (`<=`, `>=`, `!=`, `||`)
+/// are the specific trigger for the `&input[i..i+2]` boundary panic.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn check_constraint_with_utf8_near_operator() {
+    let server = TestServer::start().await;
+
+    server
+        .exec(
+            "CREATE COLLECTION utf_ck TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             name TEXT)",
+        )
+        .await
+        .unwrap();
+
+    // Check constraint with a multi-byte literal adjacent to `!=`.
+    let result = server
+        .exec(
+            "ALTER COLLECTION utf_ck ADD CONSTRAINT no_forbidden \
+             CHECK (NEW.name != '禁止')",
+        )
+        .await;
+    assert!(
+        result.is_ok(),
+        "CHECK with UTF-8 literal near != operator should not panic: {:?}",
+        result
+    );
+
+    // Valid insert should pass the constraint.
+    server
+        .exec("INSERT INTO utf_ck (id, name) VALUES ('c1', 'allowed')")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text("SELECT id FROM utf_ck WHERE id = 'c1'")
+        .await
+        .unwrap();
+    assert_eq!(rows.len(), 1);
+}
+
+/// Latin diacritics (2-byte UTF-8) in expression strings.
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+async fn generated_column_with_latin_diacritics() {
+    let server = TestServer::start().await;
+
+    let result = server
+        .exec(
+            "CREATE COLLECTION utf_lat TYPE DOCUMENT STRICT (\
+             id TEXT PRIMARY KEY, \
+             city TEXT, \
+             greeting TEXT GENERATED ALWAYS AS ('café in ' || city))",
+        )
+        .await;
+    assert!(
+        result.is_ok(),
+        "Latin diacritics in GENERATED expression should not panic: {:?}",
+        result
+    );
+
+    server
+        .exec("INSERT INTO utf_lat (id, city) VALUES ('l1', 'Paris')")
+        .await
+        .unwrap();
+
+    let rows = server
+        .query_text("SELECT greeting FROM utf_lat WHERE id = 'l1'")
+        .await
+        .unwrap();
+    assert_eq!(rows.len(), 1);
+    assert!(
+        rows[0].contains("café in Paris"),
+        "generated column should preserve diacritics: got {:?}",
+        rows[0]
+    );
+}

From e8fe4bc922a07441ab6e0de3a709be27702e31b3 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Thu, 16 Apr 2026 22:01:57 +0800
Subject: [PATCH 8/9] fix(types): use sort_by_key per clippy unnecessary_sort_by

---
 nodedb-types/src/approx/spacesaving.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nodedb-types/src/approx/spacesaving.rs b/nodedb-types/src/approx/spacesaving.rs
index 54ee07b0..d375ec36 100644
--- a/nodedb-types/src/approx/spacesaving.rs
+++ b/nodedb-types/src/approx/spacesaving.rs
@@ -55,7 +55,7 @@ impl SpaceSaving {
             .iter()
             .map(|(&item, &(count, error))| (item, count, error))
             .collect();
-        result.sort_by(|a, b| b.1.cmp(&a.1));
+        result.sort_by_key(|item| std::cmp::Reverse(item.1));
         result
     }

From 96cbd80507c56e174b7ae433e3be6cd1c8c122e4 Mon Sep 17 00:00:00 2001
From: Farhan Syah
Date: Fri, 17 Apr 2026 02:59:08 +0800
Subject: [PATCH 9/9] fix: resolve Rust 1.95 clippy lints (collapsible_match,
 sort_by_key, while_let, checked_div)

---
 nodedb-sql/src/planner/dml.rs                 | 10 +++---
 nodedb-sql/src/planner/select.rs              | 35 +++++++++----------
 .../procedural/executor/core/dispatch.rs      |  5 +--
 .../sql_plan_convert/value/time_range.rs      | 16 ++++-----
 .../control/security/catalog/checkpoints.rs   |  2 +-
 nodedb/src/control/security/explain.rs        | 13 ++++---
 nodedb/src/control/server/native/session.rs   | 15 +++-----
 nodedb/src/control/server/resp/handler.rs     | 32 ++++++++---------
 nodedb/src/control/server/sync/listener.rs    |  4 +--
 .../handlers/document/text_extract.rs         |  6 ++--
 nodedb/src/data/executor/handlers/facet.rs    |  2 +-
 nodedb/src/engine/graph/algo/kcore.rs         |  2 +-
 nodedb/src/engine/sparse/btree_scan.rs        |  2 +-
 .../src/engine/timeseries/compound_filter.rs  |  2 +-
 14 files changed, 64 insertions(+), 82 deletions(-)

diff --git a/nodedb-sql/src/planner/dml.rs b/nodedb-sql/src/planner/dml.rs
index 829ca50d..7d8af330 100644
--- a/nodedb-sql/src/planner/dml.rs
+++ b/nodedb-sql/src/planner/dml.rs
@@ -548,12 +548,10 @@ fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
             expr: inner,
             list,
             negated: false,
-        } => {
-            if is_column(inner, pk) {
-                for item in list {
-                    if let Ok(v) = expr_to_sql_value(item) {
-                        keys.push(v);
-                    }
+        } if is_column(inner, pk) => {
+            for item in list {
+                if let Ok(v) = expr_to_sql_value(item) {
+                    keys.push(v);
                 }
             }
         }
diff --git a/nodedb-sql/src/planner/select.rs b/nodedb-sql/src/planner/select.rs
index 06bec366..ad7f526a 100644
--- a/nodedb-sql/src/planner/select.rs
+++ b/nodedb-sql/src/planner/select.rs
@@ -572,25 +572,24 @@ fn try_extract_sort_search(
                 },
             }));
         }
-        SearchTrigger::TextSearch => {
-            if args.len() >= 2 {
-                let query_text = extract_string_literal(&args[1])?;
-                let limit = match plan {
-                    SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
-                    _ => 10,
-                };
-                return Ok(Some(SqlPlan::TextSearch {
-                    collection,
-                    query: query_text,
-                    top_k: limit,
-                    fuzzy: true,
-                    filters: match plan {
-                        SqlPlan::Scan { filters, .. } => filters.clone(),
-                        _ => Vec::new(),
-                    },
-                }));
-            }
+        SearchTrigger::TextSearch if args.len() >= 2 => {
+            let query_text = extract_string_literal(&args[1])?;
+            let limit = match plan {
+                SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
+                _ => 10,
+            };
+            return Ok(Some(SqlPlan::TextSearch {
+                collection,
+                query: query_text,
+                top_k: limit,
+                fuzzy: true,
+                filters: match plan {
+                    SqlPlan::Scan { filters, .. } => filters.clone(),
+                    _ => Vec::new(),
+                },
+            }));
         }
+        SearchTrigger::TextSearch => {}
         SearchTrigger::HybridSearch => {
             return plan_hybrid_from_sort(&args, &collection, plan, functions);
         }
diff --git a/nodedb/src/control/planner/procedural/executor/core/dispatch.rs b/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
index c51ad949..a3929c39 100644
--- a/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
+++ b/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
@@ -196,10 +196,7 @@ fn fold_literal_string_concat(sql: &str) -> String {
     };
 
     let mut folded = false;
-    loop {
-        let Some(op_end) = consume_string_concat_operator(bytes, cursor) else {
-            break;
-        };
+    while let Some(op_end) = consume_string_concat_operator(bytes, cursor) {
         let next_lit = skip_ascii_whitespace(bytes, op_end);
 
         let Some((next_cursor, next_literal)) = parse_single_quoted_literal(sql, next_lit) else {
diff --git a/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs b/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
index e63a8739..309ded76 100644
--- a/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
@@ -23,15 +23,15 @@ fn extract_time_bounds_from_filter(expr: &FilterExpr, min_ts: &mut i64, max_ts:
         FilterExpr::Comparison { field, op, value } if is_time_field(field) => {
             if let Some(ms) = sql_value_to_timestamp_ms(value) {
                 match op {
-                    nodedb_sql::types::CompareOp::Ge | nodedb_sql::types::CompareOp::Gt => {
-                        if ms > *min_ts {
-                            *min_ts = ms;
-                        }
+                    nodedb_sql::types::CompareOp::Ge | nodedb_sql::types::CompareOp::Gt
+                        if ms > *min_ts =>
+                    {
+                        *min_ts = ms;
                     }
-                    nodedb_sql::types::CompareOp::Le | nodedb_sql::types::CompareOp::Lt => {
-                        if ms < *max_ts {
-                            *max_ts = ms;
-                        }
+                    nodedb_sql::types::CompareOp::Le | nodedb_sql::types::CompareOp::Lt
+                        if ms < *max_ts =>
+                    {
+                        *max_ts = ms;
                     }
                     nodedb_sql::types::CompareOp::Eq => {
                         *min_ts = ms;
diff --git a/nodedb/src/control/security/catalog/checkpoints.rs b/nodedb/src/control/security/catalog/checkpoints.rs
index 0c295378..f6d915af 100644
--- a/nodedb/src/control/security/catalog/checkpoints.rs
+++ b/nodedb/src/control/security/catalog/checkpoints.rs
@@ -110,7 +110,7 @@ impl SystemCatalog {
         }
 
         // Sort by created_at descending (most recent first).
-        records.sort_by(|a, b| b.created_at.cmp(&a.created_at));
+        records.sort_by_key(|r| std::cmp::Reverse(r.created_at));
 
         if records.len() > limit && limit > 0 {
             records.truncate(limit);
diff --git a/nodedb/src/control/security/explain.rs b/nodedb/src/control/security/explain.rs
index 24ff8d0d..a5ccb90d 100644
--- a/nodedb/src/control/security/explain.rs
+++ b/nodedb/src/control/security/explain.rs
@@ -166,13 +166,12 @@ pub fn lint_predicate(predicate: &super::predicate::RlsPredicate) -> Vec<String>
         super::predicate::RlsPredicate::AlwaysFalse => {
             warnings.push("contradiction: predicate is always false (blocks everything)".into());
         }
-        super::predicate::RlsPredicate::Compare { value, .. } => {
-            if !value.is_auth_ref() && matches!(value, super::predicate::PredicateValue::Literal(_))
-            {
-                warnings.push(
-                    "static predicate: no $auth reference — same result for all users".into(),
-                );
-            }
+        super::predicate::RlsPredicate::Compare { value, .. }
+            if !value.is_auth_ref()
+                && matches!(value, super::predicate::PredicateValue::Literal(_)) =>
+        {
+            warnings
+                .push("static predicate: no $auth reference — same result for all users".into());
         }
         super::predicate::RlsPredicate::And(children)
         | super::predicate::RlsPredicate::Or(children) => {
diff --git a/nodedb/src/control/server/native/session.rs b/nodedb/src/control/server/native/session.rs
index d143859d..105945eb 100644
--- a/nodedb/src/control/server/native/session.rs
+++ b/nodedb/src/control/server/native/session.rs
@@ -418,17 +418,12 @@ fn chunk_large_response(
     };
     let sample_bytes = codec::encode_response(&sample_resp, format)?;
    let sample_count = total_rows.min(100);
-    let per_row_estimate = if sample_count > 0 {
-        sample_bytes.len() / sample_count
-    } else {
-        256 // fallback
-    };
+    let per_row_estimate = sample_bytes.len().checked_div(sample_count).unwrap_or(256);
 
-    let rows_per_chunk = if per_row_estimate > 0 {
-        (target_size / per_row_estimate).max(1)
-    } else {
-        1000
-    };
+    let rows_per_chunk = target_size
+        .checked_div(per_row_estimate)
+        .map(|v| v.max(1))
+        .unwrap_or(1000);
 
     let mut frames = Vec::new();
     let chunks: Vec<_> = rows.chunks(rows_per_chunk).collect();
diff --git a/nodedb/src/control/server/resp/handler.rs b/nodedb/src/control/server/resp/handler.rs
index 121e19cb..25d7fe9b 100644
--- a/nodedb/src/control/server/resp/handler.rs
+++ b/nodedb/src/control/server/resp/handler.rs
@@ -259,27 +259,23 @@ async fn handle_scan(cmd: &RespCommand, session: &RespSession, state: &SharedSta
                 i += 2;
             }
             // NodeDB extension: SCAN 0 FILTER <field> = <value>
-            Some(ref flag) if flag == "FILTER" => {
+            Some(ref flag) if flag == "FILTER" && i + 4 <= cmd.argc() => {
                 // Parse simple "field = value" predicate (needs 4 args: FILTER field = value).
-                if i + 4 <= cmd.argc() {
-                    let field = cmd.arg_str(i + 1).unwrap_or("");
-                    let _op = cmd.arg_str(i + 2).unwrap_or(""); // "=" expected
-                    let value = cmd.arg_str(i + 3).unwrap_or("");
-                    let scan_filter = serde_json::json!([{
-                        "field": field,
-                        "op": "eq",
-                        "value": value,
-                    }]);
-                    match nodedb_types::json_to_msgpack(&scan_filter) {
-                        Ok(bytes) => filter_bytes = bytes,
-                        Err(_) => {
-                            return RespValue::err("ERR filter serialization failed");
-                        }
+                let field = cmd.arg_str(i + 1).unwrap_or("");
+                let _op = cmd.arg_str(i + 2).unwrap_or(""); // "=" expected
+                let value = cmd.arg_str(i + 3).unwrap_or("");
+                let scan_filter = serde_json::json!([{
+                    "field": field,
+                    "op": "eq",
+                    "value": value,
+                }]);
+                match nodedb_types::json_to_msgpack(&scan_filter) {
+                    Ok(bytes) => filter_bytes = bytes,
+                    Err(_) => {
+                        return RespValue::err("ERR filter serialization failed");
                     }
-                    i += 4;
-                } else {
-                    i += 1;
                 }
+                i += 4;
             }
             _ => {
                 i += 1;
diff --git a/nodedb/src/control/server/sync/listener.rs b/nodedb/src/control/server/sync/listener.rs
index a974844b..daa451ab 100644
--- a/nodedb/src/control/server/sync/listener.rs
+++ b/nodedb/src/control/server/sync/listener.rs
@@ -316,9 +316,9 @@ async fn handle_sync_session(
                 }
             }
             Ok(Message::Ping(data)) => {
-                if ws.send(Message::Pong(data)).await.is_err() {
+                let Ok(_) = ws.send(Message::Pong(data)).await else {
                     break;
-                }
+                };
             }
             Ok(Message::Close(_)) => break,
             Err(e) => {
diff --git a/nodedb/src/data/executor/handlers/document/text_extract.rs b/nodedb/src/data/executor/handlers/document/text_extract.rs
index 9a4ba722..2646771b 100644
--- a/nodedb/src/data/executor/handlers/document/text_extract.rs
+++ b/nodedb/src/data/executor/handlers/document/text_extract.rs
@@ -14,10 +14,8 @@ pub fn extract_indexable_text(doc: &serde_json::Value) -> String {
 
 pub(super) fn collect_text(val: &serde_json::Value, parts: &mut Vec<String>) {
     match val {
-        serde_json::Value::String(s) => {
-            if !s.is_empty() {
-                parts.push(s.clone());
-            }
+        serde_json::Value::String(s) if !s.is_empty() => {
+            parts.push(s.clone());
         }
         serde_json::Value::Object(map) => {
             for v in map.values() {
diff --git a/nodedb/src/data/executor/handlers/facet.rs b/nodedb/src/data/executor/handlers/facet.rs
index d5a09d78..c5ef5eae 100644
--- a/nodedb/src/data/executor/handlers/facet.rs
+++ b/nodedb/src/data/executor/handlers/facet.rs
@@ -151,7 +151,7 @@ impl CoreLoop {
         }
 
         let mut result: Vec<(String, usize)> = counts.into_iter().collect();
-        result.sort_by(|a, b| b.1.cmp(&a.1)); // Count descending.
+        result.sort_by_key(|r| std::cmp::Reverse(r.1)); // Count descending.
         result
     }
 }
diff --git a/nodedb/src/engine/graph/algo/kcore.rs b/nodedb/src/engine/graph/algo/kcore.rs
index 8c67420f..07d1a79c 100644
--- a/nodedb/src/engine/graph/algo/kcore.rs
+++ b/nodedb/src/engine/graph/algo/kcore.rs
@@ -91,7 +91,7 @@ pub fn run(csr: &CsrIndex) -> AlgoResultBatch {
 
     // Build result sorted by coreness descending.
     let mut scored: Vec<(usize, usize)> = coreness.into_iter().enumerate().collect();
-    scored.sort_by(|a, b| b.1.cmp(&a.1));
+    scored.sort_by_key(|&(_, k)| std::cmp::Reverse(k));
 
     let mut batch = AlgoResultBatch::new(GraphAlgorithm::KCore);
     for (node, k) in scored {
diff --git a/nodedb/src/engine/sparse/btree_scan.rs b/nodedb/src/engine/sparse/btree_scan.rs
index ad48316f..ea870dbd 100644
--- a/nodedb/src/engine/sparse/btree_scan.rs
+++ b/nodedb/src/engine/sparse/btree_scan.rs
@@ -190,7 +190,7 @@ impl SparseEngine {
         }
 
         let mut result: Vec<(String, usize)> = groups.into_iter().collect();
-        result.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by count descending (most popular first).
+        result.sort_by_key(|r| std::cmp::Reverse(r.1)); // Sort by count descending (most popular first).
 
         debug!(
             collection,
             field,
diff --git a/nodedb/src/engine/timeseries/compound_filter.rs b/nodedb/src/engine/timeseries/compound_filter.rs
index c22f284a..379a84a1 100644
--- a/nodedb/src/engine/timeseries/compound_filter.rs
+++ b/nodedb/src/engine/timeseries/compound_filter.rs
@@ -93,7 +93,7 @@ impl CompoundTagIndex {
             .iter()
             .map(|(k, &v)| (k.clone(), v))
             .collect();
-        combos.sort_by(|a, b| b.1.cmp(&a.1));
+        combos.sort_by_key(|c| std::cmp::Reverse(c.1));
         combos.truncate(n);
         combos
     }
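
A note on the pattern the UTF-8 tests above pin down: the panic they guard against comes from slicing `&str` at raw byte offsets, and the fix direction is to let the standard library refuse an unaligned slice rather than abort. A minimal standalone sketch under that assumption (the helper `two_char_op_at` is illustrative, not the actual tokenizer API):

fn two_char_op_at(input: &str, i: usize) -> Option<&str> {
    // `str::get` returns None instead of panicking when the range does
    // not land on char boundaries, so a multi-byte codepoint adjacent
    // to an operator can never abort the parse.
    let two = input.get(i..i + 2)?;
    matches!(two, "<=" | ">=" | "!=" | "||").then_some(two)
}

fn main() {
    assert_eq!(two_char_op_at("a<=b", 1), Some("<="));
    // Byte 0 starts the 3-byte codepoint '你': `get` yields None, no panic.
    assert_eq!(two_char_op_at("你<=b", 0), None);
    // The operator is still found at the codepoint's end boundary.
    assert_eq!(two_char_op_at("你<=b", 3), Some("<="));
}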