diff --git a/nodedb-query/src/expr_parse.rs b/nodedb-query/src/expr_parse/mod.rs
similarity index 78%
rename from nodedb-query/src/expr_parse.rs
rename to nodedb-query/src/expr_parse/mod.rs
index 4c4f4762..fe3fb726 100644
--- a/nodedb-query/src/expr_parse.rs
+++ b/nodedb-query/src/expr_parse/mod.rs
@@ -16,8 +16,11 @@
 //!
 //! Determinism validation: rejects `NOW()`, `RANDOM()`, `NEXTVAL()`, `UUID()`.
 
+mod tokenizer;
+
 use super::expr::{BinaryOp, SqlExpr};
 use nodedb_types::Value;
+use tokenizer::{Token, TokenKind, tokenize};
 
 /// Parse a SQL expression string into an SqlExpr AST.
 ///
@@ -46,160 +49,11 @@ pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), String> {
     Ok((expr, deps))
 }
 
-// ── Tokenizer ─────────────────────────────────────────────────────────
-
-#[derive(Debug, Clone)]
-struct Token {
-    text: String,
-    kind: TokenKind,
-}
-
-#[derive(Debug, Clone, Copy, PartialEq)]
-enum TokenKind {
-    Ident,
-    Number,
-    StringLit,
-    LParen,
-    RParen,
-    Comma,
-    Op,
-}
-
-fn tokenize(input: &str) -> Result<Vec<Token>, String> {
-    let bytes = input.as_bytes();
-    let mut tokens = Vec::new();
-    let mut i = 0;
-
-    while i < bytes.len() {
-        let b = bytes[i];
-
-        // Skip whitespace.
-        if b.is_ascii_whitespace() {
-            i += 1;
-            continue;
-        }
-
-        // Single-char tokens.
-        if b == b'(' {
-            tokens.push(Token {
-                text: "(".into(),
-                kind: TokenKind::LParen,
-            });
-            i += 1;
-            continue;
-        }
-        if b == b')' {
-            tokens.push(Token {
-                text: ")".into(),
-                kind: TokenKind::RParen,
-            });
-            i += 1;
-            continue;
-        }
-        if b == b',' {
-            tokens.push(Token {
-                text: ",".into(),
-                kind: TokenKind::Comma,
-            });
-            i += 1;
-            continue;
-        }
-
-        // Two-char operators.
-        if i + 1 < bytes.len() {
-            let two = &input[i..i + 2];
-            if matches!(two, "<=" | ">=" | "!=" | "<>") {
-                tokens.push(Token {
-                    text: two.into(),
-                    kind: TokenKind::Op,
-                });
-                i += 2;
-                continue;
-            }
-            if two == "||" {
-                tokens.push(Token {
-                    text: "||".into(),
-                    kind: TokenKind::Op,
-                });
-                i += 2;
-                continue;
-            }
-        }
-
-        // Single-char operators.
-        if matches!(b, b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>') {
-            tokens.push(Token {
-                text: (b as char).to_string(),
-                kind: TokenKind::Op,
-            });
-            i += 1;
-            continue;
-        }
-
-        // String literal.
-        if b == b'\'' {
-            let mut s = String::new();
-            i += 1;
-            while i < bytes.len() {
-                if bytes[i] == b'\'' {
-                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
-                        s.push('\'');
-                        i += 2;
-                        continue;
-                    }
-                    i += 1;
-                    break;
-                }
-                s.push(bytes[i] as char);
-                i += 1;
-            }
-            tokens.push(Token {
-                text: s,
-                kind: TokenKind::StringLit,
-            });
-            continue;
-        }
-
-        // Number.
-        if b.is_ascii_digit() || (b == b'.' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit())
-        {
-            let start = i;
-            while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') {
-                i += 1;
-            }
-            tokens.push(Token {
-                text: input[start..i].to_string(),
-                kind: TokenKind::Number,
-            });
-            continue;
-        }
-
-        // Identifier or keyword.
-        if b.is_ascii_alphabetic() || b == b'_' {
-            let start = i;
-            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
-                i += 1;
-            }
-            tokens.push(Token {
-                text: input[start..i].to_string(),
-                kind: TokenKind::Ident,
-            });
-            continue;
-        }
-
-        return Err(format!("unexpected character: '{}'", b as char));
-    }
-
-    Ok(tokens)
-}
-
 // ── Recursive descent parser ──────────────────────────────────────────
 
 /// Maximum recursion depth for nested parentheses / sub-expressions.
-/// Exceeding this limit returns `Err` instead of overflowing the stack.
 const MAX_EXPR_DEPTH: usize = 128;
 
-/// Parse an expression (lowest precedence: OR).
 fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
     parse_or(tokens, pos, depth)
 }
@@ -304,13 +158,11 @@ fn parse_multiplicative(
 }
 
 fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
-    // Unary minus.
     if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
         *pos += 1;
         let expr = parse_primary(tokens, pos, depth)?;
         return Ok(SqlExpr::Negate(Box::new(expr)));
     }
-    // NOT
     if peek_keyword(tokens, *pos, "NOT") {
         *pos += 1;
         let expr = parse_primary(tokens, pos, depth)?;
@@ -327,7 +179,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
     let token = &tokens[*pos];
 
     match token.kind {
-        // Parenthesized expression.
         TokenKind::LParen => {
             *depth += 1;
             if *depth > MAX_EXPR_DEPTH {
@@ -342,7 +193,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
             Ok(expr)
         }
 
-        // Number literal.
         TokenKind::Number => {
             *pos += 1;
             if let Ok(i) = token.text.parse::<i64>() {
@@ -354,13 +204,11 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
             }
         }
 
-        // String literal.
         TokenKind::StringLit => {
             *pos += 1;
             Ok(SqlExpr::Literal(Value::String(token.text.clone())))
         }
 
-        // Identifier: column ref, function call, keyword (NULL, TRUE, FALSE, CASE, COALESCE).
         TokenKind::Ident => {
            let name = token.text.clone();
            let upper = name.to_uppercase();
@@ -376,7 +224,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
                    Ok(SqlExpr::Coalesce(args))
                }
                _ => {
-                    // Function call: IDENT(args).
                    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
                        let args = parse_arg_list(tokens, pos, depth)?;
                        Ok(SqlExpr::Function {
@@ -384,7 +231,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
                            args,
                        })
                    } else {
-                        // Column reference.
                        Ok(SqlExpr::Column(name.to_lowercase()))
                    }
                }
@@ -395,7 +241,6 @@ fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
     }
 }
 
-/// Parse `CASE WHEN cond THEN result [WHEN ... THEN ...] [ELSE result] END`.
 fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
     let mut when_thens = Vec::new();
     let mut else_expr = None;
@@ -429,7 +274,6 @@ fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
diff --git a/nodedb-query/src/expr_parse/tokenizer.rs b/nodedb-query/src/expr_parse/tokenizer.rs
new file mode 100644
--- /dev/null
+++ b/nodedb-query/src/expr_parse/tokenizer.rs
+pub(super) fn tokenize(input: &str) -> Result<Vec<Token>, String> {
+    let chars: Vec<(usize, char)> = input.char_indices().collect();
+    let mut tokens = Vec::new();
+    let mut i = 0;
+
+    while i < chars.len() {
+        let (_, ch) = chars[i];
+
+        // Skip whitespace.
+        if ch.is_ascii_whitespace() {
+            i += 1;
+            continue;
+        }
+
+        // Single-char structural tokens.
+        if ch == '(' {
+            tokens.push(Token {
+                text: "(".into(),
+                kind: TokenKind::LParen,
+            });
+            i += 1;
+            continue;
+        }
+        if ch == ')' {
+            tokens.push(Token {
+                text: ")".into(),
+                kind: TokenKind::RParen,
+            });
+            i += 1;
+            continue;
+        }
+        if ch == ',' {
+            tokens.push(Token {
+                text: ",".into(),
+                kind: TokenKind::Comma,
+            });
+            i += 1;
+            continue;
+        }
+
+        // Two-char operators: <=, >=, !=, <>, ||
+        if i + 1 < chars.len() {
+            let (_, next_ch) = chars[i + 1];
+            let two: String = [ch, next_ch].iter().collect();
+            if matches!(two.as_str(), "<=" | ">=" | "!=" | "<>" | "||") {
+                tokens.push(Token {
+                    text: two,
+                    kind: TokenKind::Op,
+                });
+                i += 2;
+                continue;
+            }
+        }
+
+        // Single-char operators.
+        if matches!(ch, '+' | '-' | '*' | '/' | '%' | '=' | '<' | '>') {
+            tokens.push(Token {
+                text: ch.to_string(),
+                kind: TokenKind::Op,
+            });
+            i += 1;
+            continue;
+        }
+
+        // String literal (single-quoted).
+        if ch == '\'' {
+            let mut s = String::new();
+            i += 1;
+            while i < chars.len() {
+                let (_, c) = chars[i];
+                if c == '\'' {
+                    // Check for '' escape.
+                    if i + 1 < chars.len() && chars[i + 1].1 == '\'' {
+                        s.push('\'');
+                        i += 2;
+                        continue;
+                    }
+                    i += 1;
+                    break;
+                }
+                s.push(c);
+                i += 1;
+            }
+            tokens.push(Token {
+                text: s,
+                kind: TokenKind::StringLit,
+            });
+            continue;
+        }
+
+        // Number.
+        if ch.is_ascii_digit()
+            || (ch == '.' && i + 1 < chars.len() && chars[i + 1].1.is_ascii_digit())
+        {
+            let start_byte = chars[i].0;
+            while i < chars.len() && (chars[i].1.is_ascii_digit() || chars[i].1 == '.') {
+                i += 1;
+            }
+            let end_byte = if i < chars.len() {
+                chars[i].0
+            } else {
+                input.len()
+            };
+            tokens.push(Token {
+                text: input[start_byte..end_byte].to_string(),
+                kind: TokenKind::Number,
+            });
+            continue;
+        }
+
+        // Identifier or keyword (ASCII letters, digits, underscore).
+        if ch.is_ascii_alphabetic() || ch == '_' {
+            let start_byte = chars[i].0;
+            while i < chars.len() && (chars[i].1.is_ascii_alphanumeric() || chars[i].1 == '_') {
+                i += 1;
+            }
+            let end_byte = if i < chars.len() {
+                chars[i].0
+            } else {
+                input.len()
+            };
+            tokens.push(Token {
+                text: input[start_byte..end_byte].to_string(),
+                kind: TokenKind::Ident,
+            });
+            continue;
+        }
+
+        // Unexpected character (including stray Unicode outside a quoted
+        // string). Previously a byte-based two-char slice could split a
+        // multibyte char and panic; reject with a descriptive error instead.
+        return Err(format!(
+            "unexpected character: '{ch}' (U+{:04X})",
+            ch as u32
+        ));
+    }
+
+    Ok(tokens)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn ascii_expression() {
+        let tokens = tokenize("price * (1 + tax_rate)").unwrap();
+        assert_eq!(tokens.len(), 7);
+    }
+
+    #[test]
+    fn cjk_string_literal() {
+        let tokens = tokenize("'你好' || name").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[0].kind, TokenKind::StringLit);
+        assert_eq!(tokens[0].text, "你好");
+    }
+
+    #[test]
+    fn emoji_string_literal() {
+        let tokens = tokenize("'🎉' || tag").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[0].text, "🎉");
+    }
+
+    #[test]
+    fn two_char_op_after_multibyte_string() {
+        // Previously panicked: &input[i..i+2] crossed a char boundary.
+        let tokens = tokenize("'你' || x").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[1].text, "||");
+    }
+
+    #[test]
+    fn escaped_quote_in_string() {
+        let tokens = tokenize("'it''s'").unwrap();
+        assert_eq!(tokens.len(), 1);
+        assert_eq!(tokens[0].text, "it's");
+    }
+
+    #[test]
+    fn latin_diacritics_in_string() {
+        let tokens = tokenize("'café'").unwrap();
+        assert_eq!(tokens[0].text, "café");
+    }
+
+    #[test]
+    fn comparison_after_cjk() {
+        let tokens = tokenize("name != '禁止'").unwrap();
+        assert_eq!(tokens.len(), 3);
+        assert_eq!(tokens[1].text, "!=");
+        assert_eq!(tokens[2].text, "禁止");
+    }
+}
diff --git a/nodedb-sql/src/lib.rs b/nodedb-sql/src/lib.rs
index 4b614fe2..8ff3be90 100644
--- a/nodedb-sql/src/lib.rs
+++ b/nodedb-sql/src/lib.rs
@@ -25,6 +25,28 @@ pub use error::{Result, SqlError};
 pub use params::ParamValue;
 pub use types::*;
 
+/// Parse a standalone SQL expression string into an `SqlExpr`.
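+///
+/// A usage sketch (hypothetical input; `Result` is this crate's alias):
+///
+/// ```ignore
+/// let expr = parse_expr_string("upper('x') || '!'")?;
+/// ```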
+///
+/// Used by the DEFAULT expression evaluator to handle arbitrary expressions
+/// (e.g. `upper('x')`, `1 + 2`) that don't match the hard-coded keyword list.
+pub fn parse_expr_string(expr_text: &str) -> Result<SqlExpr> {
+    use sqlparser::dialect::GenericDialect;
+    use sqlparser::parser::Parser;
+
+    let dialect = GenericDialect {};
+    let ast_expr = Parser::new(&dialect)
+        .try_with_sql(expr_text)
+        .map_err(|e| SqlError::Parse {
+            detail: e.to_string(),
+        })?
+        .parse_expr()
+        .map_err(|e| SqlError::Parse {
+            detail: e.to_string(),
+        })?;
+
+    resolver::expr::convert_expr(&ast_expr)
+}
+
 use functions::registry::FunctionRegistry;
 use parser::preprocess;
 use parser::statement::{StatementKind, classify, parse_sql};
diff --git a/nodedb-sql/src/parser/preprocess.rs b/nodedb-sql/src/parser/preprocess.rs
index d8f2d7ae..2309e1fd 100644
--- a/nodedb-sql/src/parser/preprocess.rs
+++ b/nodedb-sql/src/parser/preprocess.rs
@@ -270,12 +270,14 @@ fn value_to_json(value: &nodedb_types::Value) -> String {
 fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
     let mut depth = 0;
     let mut in_string = false;
-    for i in start..chars.len() {
+    let mut i = start;
+    while i < chars.len() {
         match chars[i] {
             '\'' if !in_string => in_string = true,
             '\'' if in_string => {
                 if i + 1 < chars.len() && chars[i + 1] == '\'' {
-                    // Skip escaped quote.
+                    // Skip '' escape — advance past both quotes.
+                    i += 2;
                     continue;
                 }
                 in_string = false;
@@ -289,6 +291,7 @@ fn find_matching_brace(chars: &[char], start: usize) -> Option<usize> {
             }
             _ => {}
         }
+        i += 1;
     }
     None
 }
diff --git a/nodedb-sql/src/planner/const_fold.rs b/nodedb-sql/src/planner/const_fold.rs
index 26a1fadb..d13f6028 100644
--- a/nodedb-sql/src/planner/const_fold.rs
+++ b/nodedb-sql/src/planner/const_fold.rs
@@ -64,9 +64,9 @@ pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
 ) -> Option<SqlValue> {
     Some(match (l, op, r) {
-        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a + b),
-        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a - b),
-        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a * b),
+        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
+        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
+        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
         (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
         (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
         (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
diff --git a/nodedb-sql/src/planner/cte.rs b/nodedb-sql/src/planner/cte.rs
index 06a994da..77ce72b3 100644
--- a/nodedb-sql/src/planner/cte.rs
+++ b/nodedb-sql/src/planner/cte.rs
@@ -4,9 +4,15 @@ use sqlparser::ast::{self, Query, SetExpr};
 
 use crate::error::{Result, SqlError};
 use crate::functions::registry::FunctionRegistry;
+use crate::parser::normalize::{normalize_ident, normalize_object_name};
 use crate::types::*;
 
 /// Plan a WITH RECURSIVE query.
+///
+/// Supports table-based recursive CTEs where the base case scans a real
+/// collection and the recursive step references both the collection and
+/// the CTE. Value-generating CTEs (no underlying collection) return an
+/// explicit unsupported error.
 pub fn plan_recursive_cte(
     query: &Query,
     catalog: &dyn SqlCatalog,
@@ -16,41 +22,227 @@ pub fn plan_recursive_cte(
         detail: "expected WITH clause".into(),
     })?;
 
-    // Get the CTE definition.
     let cte = with.cte_tables.first().ok_or_else(|| SqlError::Parse {
         detail: "empty WITH clause".into(),
     })?;
+    let cte_name = normalize_ident(&cte.alias.name);
+    let cte_query = &cte.query;
 
     // The CTE body should be a UNION of base case and recursive case.
-    match &*cte_query.body {
+    let (left, right, set_quantifier) = match &*cte_query.body {
         SetExpr::SetOperation {
             op: ast::SetOperator::Union,
             left,
             right,
-            ..
+            set_quantifier,
+        } => (left, right, set_quantifier),
+        _ => {
+            return Err(SqlError::Unsupported {
+                detail: "WITH RECURSIVE requires UNION in CTE body".into(),
+            });
+        }
+    };
+
+    // UNION ALL → distinct = false; UNION → distinct = true.
+    let distinct = !matches!(set_quantifier, ast::SetQuantifier::All);
+
+    // Plan the base case (should not reference the CTE name).
+    let base = plan_cte_branch(left, catalog, functions)?;
+
+    // Extract the source collection from the base case.
+    let collection = extract_collection(&base).unwrap_or_default();
+
+    // Plan the recursive branch. The recursive branch references the CTE
+    // name in its FROM clause — either directly (value-gen) or via a JOIN
+    // with a real table. We attempt to plan it; if it fails because the
+    // CTE name isn't in the catalog, we try to extract the real table from
+    // a JOIN and use it with the CTE self-reference as the recursive filter.
+    let (recursive_filters, join_link) = match plan_cte_branch(right, catalog, functions) {
+        Ok(plan) => (extract_filters(&plan), None),
+        Err(_) => {
+            // The recursive branch references the CTE name. Try to extract
+            // the real collection, filters, and join link from the AST.
+            extract_recursive_info(right, &cte_name)?
+        }
+    };
+
+    if collection.is_empty() {
+        return Err(SqlError::Unsupported {
+            detail: "WITH RECURSIVE requires a base case that scans a collection; \
+                     value-generating recursive CTEs are not yet supported"
+                .into(),
+        });
+    }
+
+    Ok(SqlPlan::RecursiveScan {
+        collection,
+        base_filters: extract_filters(&base),
+        recursive_filters,
+        join_link,
+        max_iterations: 100,
+        distinct,
+        limit: 10000,
+    })
+}
+
+/// Extract recursive info from the AST when normal planning fails
+/// because the FROM clause references the CTE name.
+///
+/// Returns `(filters, join_link)` where `join_link` is the
+/// `(collection_field, working_table_field)` pair for the working-table
+/// hash-join.
+///
+/// Handles the common tree-traversal pattern:
+/// `SELECT t.id FROM tree t INNER JOIN cte_name d ON t.parent_id = d.id`
+/// → join_link = `("parent_id", "id")`
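+///
+/// With that link, each iteration conceptually runs (sketch only; the
+/// hypothetical field names follow the example above):
+/// `SELECT t.* FROM tree t WHERE t.parent_id IN (working-table ids)`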
+type RecursiveInfo = (Vec<Filter>, Option<(String, String)>);
+
+fn extract_recursive_info(expr: &SetExpr, cte_name: &str) -> Result<RecursiveInfo> {
+    let select = match expr {
+        SetExpr::Select(s) => s,
+        _ => {
+            return Err(SqlError::Unsupported {
+                detail: "recursive CTE branch must be SELECT".into(),
+            });
+        }
+    };
+
+    let mut real_table_alias = None;
+    let mut cte_alias = None;
+    let mut join_on_expr = None;
+
+    for from in &select.from {
+        let table_name = extract_table_name(&from.relation);
+        let table_alias = extract_table_alias(&from.relation);
+
+        if let Some(name) = &table_name {
+            if name.eq_ignore_ascii_case(cte_name) {
+                cte_alias = table_alias.or_else(|| Some(name.clone()));
+            } else {
+                real_table_alias = table_alias.or_else(|| Some(name.clone()));
+            }
+        }
+
+        for join in &from.joins {
+            let join_table = extract_table_name(&join.relation);
+            let join_alias = extract_table_alias(&join.relation);
+            if let Some(jt) = &join_table {
+                if jt.eq_ignore_ascii_case(cte_name) {
+                    cte_alias = join_alias.or_else(|| Some(jt.clone()));
+                    if let Some(cond) = extract_join_on_condition(&join.join_operator) {
+                        join_on_expr = Some(cond.clone());
+                    }
+                } else {
+                    real_table_alias = join_alias.or_else(|| Some(jt.clone()));
+                    if join_on_expr.is_none()
+                        && let Some(cond) = extract_join_on_condition(&join.join_operator)
+                    {
+                        join_on_expr = Some(cond.clone());
+                    }
+                }
+            }
+        }
+    }
+
+    // Extract the join link from the ON condition.
+    let join_link = if let (Some(real_alias), Some(cte_al), Some(on_expr)) =
+        (&real_table_alias, &cte_alias, &join_on_expr)
+    {
+        extract_equi_link(on_expr, real_alias, cte_al)
+    } else {
+        None
+    };
+
+    // Convert the WHERE clause to filters if present.
+    let mut filters = Vec::new();
+    if let Some(where_expr) = &select.selection {
+        let converted = crate::resolver::expr::convert_expr(where_expr)?;
+        filters.push(Filter {
+            expr: FilterExpr::Expr(converted),
+        });
+    }
+
+    Ok((filters, join_link))
+}
+
+/// Extract `(collection_field, cte_field)` from an equi-join ON clause.
+///
+/// Given `t.parent_id = d.id` where `t` is the real table alias and `d`
+/// is the CTE alias, returns `("parent_id", "id")`.
+fn extract_equi_link(
+    expr: &ast::Expr,
+    real_alias: &str,
+    cte_alias: &str,
+) -> Option<(String, String)> {
+    match expr {
+        ast::Expr::BinaryOp {
+            left,
+            op: ast::BinaryOperator::Eq,
+            right,
         } => {
-            let base = plan_cte_branch(left, catalog, functions)?;
-            let recursive = plan_cte_branch(right, catalog, functions)?;
-
-            // Extract collection and filters from base/recursive plans.
-            let collection = extract_collection(&base)
-                .or_else(|| extract_collection(&recursive))
-                .unwrap_or_default();
-
-            Ok(SqlPlan::RecursiveScan {
-                collection,
-                base_filters: extract_filters(&base),
-                recursive_filters: extract_filters(&recursive),
-                max_iterations: 100,
-                distinct: true,
-                limit: 10000,
-            })
+            let left_parts = extract_qualified_column(left)?;
+            let right_parts = extract_qualified_column(right)?;
+
+            // Determine which side is the real table and which is the CTE.
+            if left_parts.0.eq_ignore_ascii_case(real_alias)
+                && right_parts.0.eq_ignore_ascii_case(cte_alias)
+            {
+                Some((left_parts.1, right_parts.1))
+            } else if right_parts.0.eq_ignore_ascii_case(real_alias)
+                && left_parts.0.eq_ignore_ascii_case(cte_alias)
+            {
+                Some((right_parts.1, left_parts.1))
+            } else {
+                None
+            }
         }
-        _ => Err(SqlError::Unsupported {
-            detail: "WITH RECURSIVE requires UNION in CTE body".into(),
-        }),
+        // For AND-combined conditions, take the first equi-link found.
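+        // e.g. for a hypothetical `t.parent_id = d.id AND t.active = true`,
+        // only the equi pair ("parent_id", "id") is kept; other conjuncts
+        // are ignored here.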
+        ast::Expr::BinaryOp {
+            left,
+            op: ast::BinaryOperator::And,
+            right,
+        } => extract_equi_link(left, real_alias, cte_alias)
+            .or_else(|| extract_equi_link(right, real_alias, cte_alias)),
+        _ => None,
+    }
+}
+
+/// Extract `(table_or_alias, column)` from a qualified column reference.
+fn extract_qualified_column(expr: &ast::Expr) -> Option<(String, String)> {
+    match expr {
+        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
+            Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
+        }
+        _ => None,
+    }
+}
+
+fn extract_table_name(relation: &ast::TableFactor) -> Option<String> {
+    match relation {
+        ast::TableFactor::Table { name, .. } => Some(normalize_object_name(name)),
+        _ => None,
+    }
+}
+
+fn extract_table_alias(relation: &ast::TableFactor) -> Option<String> {
+    match relation {
+        ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
+        _ => None,
+    }
+}
+
+fn extract_join_on_condition(op: &ast::JoinOperator) -> Option<&ast::Expr> {
+    use ast::JoinOperator::*;
+    let constraint = match op {
+        Inner(c) | LeftOuter(c) | RightOuter(c) | FullOuter(c) => c,
+        _ => return None,
+    };
+    match constraint {
+        ast::JoinConstraint::On(expr) => Some(expr),
+        _ => None,
+    }
 }
diff --git a/nodedb-sql/src/planner/dml.rs b/nodedb-sql/src/planner/dml.rs
index 829ca50d..7d8af330 100644
--- a/nodedb-sql/src/planner/dml.rs
+++ b/nodedb-sql/src/planner/dml.rs
@@ -548,12 +548,10 @@ fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
             expr: inner,
             list,
             negated: false,
-        } => {
-            if is_column(inner, pk) {
-                for item in list {
-                    if let Ok(v) = expr_to_sql_value(item) {
-                        keys.push(v);
-                    }
+        } if is_column(inner, pk) => {
+            for item in list {
+                if let Ok(v) = expr_to_sql_value(item) {
+                    keys.push(v);
                 }
             }
         }
diff --git a/nodedb-sql/src/planner/join.rs b/nodedb-sql/src/planner/join.rs
index a0f46a3a..225de587 100644
--- a/nodedb-sql/src/planner/join.rs
+++ b/nodedb-sql/src/planner/join.rs
@@ -204,7 +204,16 @@ fn extract_join_constraint(constraint: &ast::JoinConstraint) -> Result<(Vec<(String, String)>, Option<SqlExpr>)> {
         }
-        ast::JoinConstraint::Natural => Ok((Vec::new(), None)),
-        ast::JoinConstraint::None => Ok((Vec::new(), None)),
+        ast::JoinConstraint::Natural => Err(crate::error::SqlError::Unsupported {
+            detail: "NATURAL JOIN is not supported; use explicit ON or USING clause".into(),
+        }),
+        ast::JoinConstraint::None => Err(crate::error::SqlError::Unsupported {
+            detail: "implicit cross join (no ON/USING clause) is not supported".into(),
+        }),
     }
 }
diff --git a/nodedb-sql/src/planner/select.rs b/nodedb-sql/src/planner/select.rs
index 06bec366..ad7f526a 100644
--- a/nodedb-sql/src/planner/select.rs
+++ b/nodedb-sql/src/planner/select.rs
@@ -572,25 +572,24 @@ fn try_extract_sort_search(
                 },
             }));
         }
-        SearchTrigger::TextSearch => {
-            if args.len() >= 2 {
-                let query_text = extract_string_literal(&args[1])?;
-                let limit = match plan {
-                    SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
-                    _ => 10,
-                };
-                return Ok(Some(SqlPlan::TextSearch {
-                    collection,
-                    query: query_text,
-                    top_k: limit,
-                    fuzzy: true,
-                    filters: match plan {
-                        SqlPlan::Scan { filters, .. } => filters.clone(),
-                        _ => Vec::new(),
-                    },
-                }));
-            }
+        SearchTrigger::TextSearch if args.len() >= 2 => {
+            let query_text = extract_string_literal(&args[1])?;
+            let limit = match plan {
+                SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
+                _ => 10,
+            };
+            return Ok(Some(SqlPlan::TextSearch {
+                collection,
+                query: query_text,
+                top_k: limit,
+                fuzzy: true,
+                filters: match plan {
+                    SqlPlan::Scan { filters, ..
+                    } => filters.clone(),
+                    _ => Vec::new(),
+                },
+            }));
         }
+        SearchTrigger::TextSearch => {}
         SearchTrigger::HybridSearch => {
             return plan_hybrid_from_sort(&args, &collection, plan, functions);
         }
diff --git a/nodedb-sql/src/types.rs b/nodedb-sql/src/types.rs
index 9718275f..31dd4367 100644
--- a/nodedb-sql/src/types.rs
+++ b/nodedb-sql/src/types.rs
@@ -197,6 +197,11 @@ pub enum SqlPlan {
         collection: String,
         base_filters: Vec<Filter>,
         recursive_filters: Vec<Filter>,
+        /// Equi-join link for tree-traversal recursion:
+        /// `(collection_field, working_table_field)`.
+        /// e.g. `("parent_id", "id")` means each iteration finds rows
+        /// where `collection.parent_id` matches a `working_table.id`.
+        join_link: Option<(String, String)>,
         max_iterations: usize,
         distinct: bool,
         limit: usize,
diff --git a/nodedb-types/src/approx/spacesaving.rs b/nodedb-types/src/approx/spacesaving.rs
index 54ee07b0..d375ec36 100644
--- a/nodedb-types/src/approx/spacesaving.rs
+++ b/nodedb-types/src/approx/spacesaving.rs
@@ -55,7 +55,7 @@ impl SpaceSaving {
             .iter()
             .map(|(&item, &(count, error))| (item, count, error))
             .collect();
-        result.sort_by(|a, b| b.1.cmp(&a.1));
+        result.sort_by_key(|item| std::cmp::Reverse(item.1));
         result
     }
 
diff --git a/nodedb/src/bridge/physical_plan/query.rs b/nodedb/src/bridge/physical_plan/query.rs
index 1a5122aa..7eb18ab4 100644
--- a/nodedb/src/bridge/physical_plan/query.rs
+++ b/nodedb/src/bridge/physical_plan/query.rs
@@ -187,6 +187,11 @@ pub enum QueryOp {
         base_filters: Vec<u8>,
         /// Recursive step filters (applied to working table each iteration).
         recursive_filters: Vec<u8>,
+        /// Equi-join link for tree-traversal recursion:
+        /// `(collection_field, working_table_field)`.
+        /// Each iteration finds rows where `collection_field` value
+        /// matches a `working_table_field` value from the previous iteration.
+        join_link: Option<(String, String)>,
         /// Maximum iterations to prevent infinite loops. Default: 100.
         max_iterations: usize,
         /// Whether to deduplicate results (UNION vs UNION ALL).
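To make the `join_link` contract concrete, here is a minimal self-contained sketch of the fixed-point iteration it drives. The `Row` and `fetch_rows` types are hypothetical stand-ins, dedup (`distinct`) is omitted, and this is not the executor's actual code:

```rust
use std::collections::{HashMap, HashSet};

type Row = HashMap<String, String>;

// Hypothetical storage probe: every row whose `field` value is in `keys`.
fn fetch_rows(_field: &str, _keys: &HashSet<String>) -> Vec<Row> {
    Vec::new() // stand-in for the real collection scan
}

// join_link = ("parent_id", "id"): the next frontier is the set of rows
// whose `parent_id` matches an `id` produced by the previous iteration.
fn recursive_scan_sketch(
    seed: Vec<Row>,
    join_link: (&str, &str),
    max_iterations: usize,
    limit: usize,
) -> Vec<Row> {
    let (collection_field, working_field) = join_link;
    let mut result = seed.clone();
    let mut working = seed;
    for _ in 0..max_iterations {
        // Keys exposed by the current working table.
        let keys: HashSet<String> = working
            .iter()
            .filter_map(|r| r.get(working_field).cloned())
            .collect();
        let next = fetch_rows(collection_field, &keys);
        if next.is_empty() || result.len() >= limit {
            break; // fixed point reached, or row budget exhausted
        }
        result.extend(next.iter().cloned());
        working = next;
    }
    result
}
```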
diff --git a/nodedb/src/control/planner/procedural/executor/core/dispatch.rs b/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
index c51ad949..a3929c39 100644
--- a/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
+++ b/nodedb/src/control/planner/procedural/executor/core/dispatch.rs
@@ -196,10 +196,7 @@ fn fold_literal_string_concat(sql: &str) -> String {
     };
 
     let mut folded = false;
-    loop {
-        let Some(op_end) = consume_string_concat_operator(bytes, cursor) else {
-            break;
-        };
+    while let Some(op_end) = consume_string_concat_operator(bytes, cursor) {
         let next_lit = skip_ascii_whitespace(bytes, op_end);
 
         let Some((next_cursor, next_literal)) = parse_single_quoted_literal(sql, next_lit) else {
diff --git a/nodedb/src/control/planner/procedural/executor/core/mod.rs b/nodedb/src/control/planner/procedural/executor/core/mod.rs
index 86ec06fa..dbfa9caa 100644
--- a/nodedb/src/control/planner/procedural/executor/core/mod.rs
+++ b/nodedb/src/control/planner/procedural/executor/core/mod.rs
@@ -100,7 +100,7 @@ impl<'a> StatementExecutor<'a> {
         block: &ProceduralBlock,
         bindings: &RowBindings,
     ) -> crate::Result<()> {
-        let mut budget = ExecutionBudget::unlimited();
+        let mut budget = ExecutionBudget::trigger_default();
         self.execute_block_with_exceptions(
             &block.statements,
             &block.exception_handlers,
diff --git a/nodedb/src/control/planner/procedural/executor/eval.rs b/nodedb/src/control/planner/procedural/executor/eval.rs
index ea0bf2c0..343be507 100644
--- a/nodedb/src/control/planner/procedural/executor/eval.rs
+++ b/nodedb/src/control/planner/procedural/executor/eval.rs
@@ -174,11 +174,11 @@ fn eval_binary_op(
 
     match (l, r) {
         (Value::Integer(a), Value::Integer(b)) => match op {
-            Plus => Some(Value::Integer(a + b)),
-            Minus => Some(Value::Integer(a - b)),
-            Multiply => Some(Value::Integer(a * b)),
-            Divide if *b != 0 => Some(Value::Integer(a / b)),
-            Modulo if *b != 0 => Some(Value::Integer(a % b)),
+            Plus => Some(Value::Integer(a.checked_add(*b)?)),
+            Minus => Some(Value::Integer(a.checked_sub(*b)?)),
+            Multiply => Some(Value::Integer(a.checked_mul(*b)?)),
+            Divide if *b != 0 => Some(Value::Integer(a.checked_div(*b)?)),
+            Modulo if *b != 0 => Some(Value::Integer(a.checked_rem(*b)?)),
             Gt => Some(Value::Bool(a > b)),
             GtEq => Some(Value::Bool(a >= b)),
             Lt => Some(Value::Bool(a < b)),
@@ -191,7 +191,14 @@ fn eval_binary_op(
             Plus => Some(Value::Float(a + b)),
             Minus => Some(Value::Float(a - b)),
             Multiply => Some(Value::Float(a * b)),
-            Divide if *b != 0.0 => Some(Value::Float(a / b)),
+            Divide if *b != 0.0 && b.is_finite() => {
+                let result = a / b;
+                if result.is_finite() {
+                    Some(Value::Float(result))
+                } else {
+                    None
+                }
+            }
             Gt => Some(Value::Bool(a > b)),
             GtEq => Some(Value::Bool(a >= b)),
             Lt => Some(Value::Bool(a < b)),
diff --git a/nodedb/src/control/planner/procedural/executor/fuel.rs b/nodedb/src/control/planner/procedural/executor/fuel.rs
index 88ff3f6b..c2f770e8 100644
--- a/nodedb/src/control/planner/procedural/executor/fuel.rs
+++ b/nodedb/src/control/planner/procedural/executor/fuel.rs
@@ -32,14 +32,10 @@ impl ExecutionBudget {
         }
     }
 
-    /// Create an unlimited budget (for triggers with no explicit limits).
-    pub fn unlimited() -> Self {
-        Self {
-            fuel_remaining: u64::MAX,
-            deadline: Instant::now() + std::time::Duration::from_secs(3600),
-            max_iterations: u64::MAX,
-            timeout_secs: 3600,
-        }
-    }
+    /// Default budget for trigger bodies. Caps iterations and wall-clock
+    /// time to prevent runaway loops from pinning Control Plane workers.
+    pub fn trigger_default() -> Self {
+        Self::new(100_000, 10)
+    }
 
     /// Consume one iteration of fuel. Returns error if exhausted.
@@ -103,8 +99,8 @@ mod tests {
     }
 
     #[test]
-    fn unlimited() {
-        let mut budget = ExecutionBudget::unlimited();
+    fn trigger_default_allows_bounded_loops() {
+        let mut budget = ExecutionBudget::trigger_default();
         for _ in 0..10_000 {
             budget.consume_iteration().unwrap();
         }
diff --git a/nodedb/src/control/planner/procedural/executor/plan_cache.rs b/nodedb/src/control/planner/procedural/executor/plan_cache.rs
index 680d4855..bc90c059 100644
--- a/nodedb/src/control/planner/procedural/executor/plan_cache.rs
+++ b/nodedb/src/control/planner/procedural/executor/plan_cache.rs
@@ -27,6 +27,7 @@ struct CacheInner {
 }
 
 struct CacheEntry {
+    body_sql: String,
     block: Arc<ProceduralBlock>,
 }
 
@@ -58,9 +59,16 @@ impl ProcedureBlockCache {
         // is uncontended in the cache-hit path (early return).
         let mut inner = self.inner.lock().unwrap_or_else(|p| p.into_inner());
 
-        // Cache hit — return immediately.
+        // Cache hit — verify the body matches (not just the hash) to guard
+        // against 64-bit hash collisions returning the wrong block.
         if let Some(entry) = inner.entries.get(&key) {
-            return Ok(Arc::clone(&entry.block));
+            if entry.body_sql == body_sql {
+                return Ok(Arc::clone(&entry.block));
+            }
+            // Hash collision with different body — evict the stale entry
+            // and fall through to re-parse.
+            inner.entries.remove(&key);
+            inner.order.retain(|k| *k != key);
         }
 
         // Cache miss — parse, cache, return.
@@ -76,6 +84,7 @@ impl ProcedureBlockCache {
         inner.entries.insert(
             key,
             CacheEntry {
+                body_sql: body_sql.to_string(),
                 block: Arc::clone(&arc),
             },
         );
diff --git a/nodedb/src/control/planner/sql_plan_convert/convert.rs b/nodedb/src/control/planner/sql_plan_convert/convert.rs
index 028c6729..48259035 100644
--- a/nodedb/src/control/planner/sql_plan_convert/convert.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/convert.rs
@@ -134,15 +134,16 @@ pub(super) fn convert_one(
             right,
             on,
             join_type,
+            condition,
             limit,
             projection,
             filters,
-            ..
         } => super::scan::convert_join(super::scan_params::JoinPlanParams {
             left,
             right,
             on,
             join_type,
+            condition,
             limit,
             projection,
             filters,
@@ -275,18 +276,20 @@ pub(super) fn convert_one(
             collection,
             base_filters,
             recursive_filters,
+            join_link,
             max_iterations,
             distinct,
             limit,
-        } => super::scan::convert_recursive_scan(
+        } => super::scan::convert_recursive_scan(super::scan_params::RecursiveScanParams {
             collection,
             base_filters,
             recursive_filters,
+            join_link,
             max_iterations,
             distinct,
             limit,
             tenant_id,
-        ),
+        }),
 
         SqlPlan::Cte { definitions, outer } => {
             super::set_ops::convert_cte(definitions, outer, tenant_id, ctx)
diff --git a/nodedb/src/control/planner/sql_plan_convert/expr.rs b/nodedb/src/control/planner/sql_plan_convert/expr.rs
index b2567f6e..d7e85877 100644
--- a/nodedb/src/control/planner/sql_plan_convert/expr.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/expr.rs
@@ -6,13 +6,36 @@ use super::value::sql_value_to_nodedb_value;
 
 /// Convert a `nodedb_sql::types::SqlExpr` (parser AST) to a
 /// `nodedb_query::expr::SqlExpr` (bridge evaluation type).
+///
+/// Column references use the **bare** name (no table qualifier) for
+/// single-collection evaluation contexts (WHERE, CHECK, GENERATED).
+/// For join contexts where the merged document uses qualified keys
+/// (`"t1.col"`), use [`sql_expr_to_bridge_expr_qualified`] instead.
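+///
+/// e.g. for a hypothetical `WHERE t1.price > 0`: the bare form looks up
+/// document key `"price"`, the qualified form looks up `"t1.price"`.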
 pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
+    convert_expr_inner(expr, false)
+}
+
+/// Like [`sql_expr_to_bridge_expr`] but qualifies column references
+/// with their table name (`t.col` → `"t.col"`) for join merged docs.
+pub(super) fn sql_expr_to_bridge_expr_qualified(
+    expr: &SqlExpr,
+) -> crate::bridge::expr_eval::SqlExpr {
+    convert_expr_inner(expr, true)
+}
+
+fn convert_expr_inner(expr: &SqlExpr, qualify: bool) -> crate::bridge::expr_eval::SqlExpr {
     use crate::bridge::expr_eval::SqlExpr as BExpr;
 
     match expr {
-        SqlExpr::Column { name, .. } => BExpr::Column(name.clone()),
+        SqlExpr::Column { table, name } => {
+            if qualify {
+                BExpr::Column(nodedb_sql::planner::qualified_name(table.as_deref(), name))
+            } else {
+                BExpr::Column(name.clone())
+            }
+        }
         SqlExpr::Literal(v) => BExpr::Literal(sql_value_to_nodedb_value(v)),
         SqlExpr::BinaryOp { left, op, right } => BExpr::BinaryOp {
-            left: Box::new(sql_expr_to_bridge_expr(left)),
+            left: Box::new(convert_expr_inner(left, qualify)),
             op: match op {
                 nodedb_sql::types::BinaryOp::Add => crate::bridge::expr_eval::BinaryOp::Add,
                 nodedb_sql::types::BinaryOp::Sub => crate::bridge::expr_eval::BinaryOp::Sub,
@@ -29,11 +52,14 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
                 nodedb_sql::types::BinaryOp::Or => crate::bridge::expr_eval::BinaryOp::Or,
                 nodedb_sql::types::BinaryOp::Concat => crate::bridge::expr_eval::BinaryOp::Concat,
             },
-            right: Box::new(sql_expr_to_bridge_expr(right)),
+            right: Box::new(convert_expr_inner(right, qualify)),
         },
         SqlExpr::Function { name, args, .. } => BExpr::Function {
             name: name.clone(),
-            args: args.iter().map(sql_expr_to_bridge_expr).collect(),
+            args: args
+                .iter()
+                .map(|a| convert_expr_inner(a, qualify))
+                .collect(),
         },
         SqlExpr::Case {
             operand,
@@ -42,14 +68,19 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
         } => BExpr::Case {
             operand: operand
                 .as_ref()
-                .map(|e| Box::new(sql_expr_to_bridge_expr(e))),
+                .map(|e| Box::new(convert_expr_inner(e, qualify))),
             when_thens: when_then
                 .iter()
-                .map(|(w, t)| (sql_expr_to_bridge_expr(w), sql_expr_to_bridge_expr(t)))
+                .map(|(w, t)| {
+                    (
+                        convert_expr_inner(w, qualify),
+                        convert_expr_inner(t, qualify),
+                    )
+                })
                 .collect(),
             else_expr: else_expr
                 .as_ref()
-                .map(|e| Box::new(sql_expr_to_bridge_expr(e))),
+                .map(|e| Box::new(convert_expr_inner(e, qualify))),
         },
         SqlExpr::Cast { expr, to_type } => {
             let cast_type = match to_type.to_uppercase().as_str() {
@@ -63,18 +94,18 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
                 _ => crate::bridge::expr_eval::CastType::String,
             };
             BExpr::Cast {
-                expr: Box::new(sql_expr_to_bridge_expr(expr)),
+                expr: Box::new(convert_expr_inner(expr, qualify)),
                 to_type: cast_type,
             }
         }
         SqlExpr::Wildcard => BExpr::Column("*".into()),
 
         // NOT e / -e → evaluator's Negate (handles both bool and numeric).
-        SqlExpr::UnaryOp { expr, .. } => BExpr::Negate(Box::new(sql_expr_to_bridge_expr(expr))),
+        SqlExpr::UnaryOp { expr, .. } => BExpr::Negate(Box::new(convert_expr_inner(expr, qualify))),
 
         // `e IS NULL` / `e IS NOT NULL` — direct passthrough.
         SqlExpr::IsNull { expr, negated } => BExpr::IsNull {
-            expr: Box::new(sql_expr_to_bridge_expr(expr)),
+            expr: Box::new(convert_expr_inner(expr, qualify)),
             negated: *negated,
         },
 
@@ -87,9 +118,9 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
             high,
             negated,
         } => {
-            let e = sql_expr_to_bridge_expr(expr);
-            let l = sql_expr_to_bridge_expr(low);
-            let h = sql_expr_to_bridge_expr(high);
+            let e = convert_expr_inner(expr, qualify);
+            let l = convert_expr_inner(low, qualify);
+            let h = convert_expr_inner(high, qualify);
             if *negated {
                 let lt = BExpr::BinaryOp {
                     left: Box::new(e.clone()),
@@ -134,7 +165,7 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
             list,
             negated,
         } => {
-            let target = sql_expr_to_bridge_expr(expr);
+            let target = convert_expr_inner(expr, qualify);
             if list.is_empty() {
                 // Empty list: `e IN ()` = false, `e NOT IN ()` = true.
                 return BExpr::Literal(nodedb_types::Value::Bool(*negated));
@@ -157,7 +188,7 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
                 .map(|item| BExpr::BinaryOp {
                     left: Box::new(target.clone()),
                     op: eq_op,
-                    right: Box::new(sql_expr_to_bridge_expr(item)),
+                    right: Box::new(convert_expr_inner(item, qualify)),
                 })
                 .reduce(|acc, next| BExpr::BinaryOp {
                     left: Box::new(acc),
@@ -178,8 +209,8 @@ pub(super) fn sql_expr_to_bridge_expr(expr: &SqlExpr) -> crate::bridge::expr_eval::SqlExpr {
             let call = BExpr::Function {
                 name: "like".into(),
                 args: vec![
-                    sql_expr_to_bridge_expr(expr),
-                    sql_expr_to_bridge_expr(pattern),
+                    convert_expr_inner(expr, qualify),
+                    convert_expr_inner(pattern, qualify),
                 ],
             };
             if *negated {
diff --git a/nodedb/src/control/planner/sql_plan_convert/filter.rs b/nodedb/src/control/planner/sql_plan_convert/filter.rs
index 5ccb90a3..ed0162f8 100644
--- a/nodedb/src/control/planner/sql_plan_convert/filter.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/filter.rs
@@ -122,7 +122,7 @@ fn filter_to_scan_filters(expr: &FilterExpr) -> Vec<nodedb_query::scan_filter::ScanFilter> {
 
-fn expr_filter(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter {
+pub(super) fn expr_filter(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter {
     nodedb_query::scan_filter::ScanFilter {
         field: String::new(),
         op: nodedb_query::scan_filter::FilterOp::Expr,
@@ -132,6 +132,18 @@ fn expr_filter(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter {
     }
 }
 
+/// Like [`expr_filter`] but qualifies column references with table names
+/// for evaluation against join-merged documents.
+pub(super) fn expr_filter_qualified(expr: &SqlExpr) -> nodedb_query::scan_filter::ScanFilter {
+    nodedb_query::scan_filter::ScanFilter {
+        field: String::new(),
+        op: nodedb_query::scan_filter::FilterOp::Expr,
+        value: nodedb_types::Value::Null,
+        clauses: Vec::new(),
+        expr: Some(super::expr::sql_expr_to_bridge_expr_qualified(expr)),
+    }
+}
+
 /// Convert a raw `SqlExpr` (from WHERE clause) to a `ScanFilter` list.
 ///
 /// Tries to produce simple, field-indexed filters for common cases (direct
diff --git a/nodedb/src/control/planner/sql_plan_convert/scan.rs b/nodedb/src/control/planner/sql_plan_convert/scan.rs
index 3d596fa5..c5ab8701 100644
--- a/nodedb/src/control/planner/sql_plan_convert/scan.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/scan.rs
@@ -13,15 +13,49 @@ use super::aggregate::{
 };
 use super::convert::convert_one;
 use super::expr::convert_sort_keys;
-use super::filter::serialize_filters;
+use super::filter::{expr_filter_qualified, serialize_filters};
 use super::scan_params::{
-    HybridSearchParams, JoinPlanParams, ScanParams, SpatialScanParams, TimeseriesScanParams,
+    HybridSearchParams, JoinPlanParams, RecursiveScanParams, ScanParams, SpatialScanParams,
+    TimeseriesScanParams,
 };
 use super::value::{
     extract_time_range, row_to_msgpack, sql_value_to_bytes, sql_value_to_string,
     write_msgpack_array_header,
 };
 
+/// Serialize WHERE filters + non-equi join condition into a single `Vec<u8>`.
+///
+/// The non-equi condition (from the ON clause) is appended as a
+/// `FilterOp::Expr` ScanFilter so the join executor evaluates it on
+/// merged rows alongside any post-join WHERE filters.
+fn serialize_join_filters(
+    filters: &[Filter],
+    condition: &Option<SqlExpr>,
+) -> crate::Result<Vec<u8>> {
+    match condition {
+        None => serialize_filters(filters),
+        Some(cond) => {
+            // Deserialize existing filters (if any), append condition, re-serialize.
+            let mut scan_filters: Vec<nodedb_query::scan_filter::ScanFilter> =
+                if !filters.is_empty() {
+                    let base = serialize_filters(filters)?;
+                    if base.is_empty() {
+                        Vec::new()
+                    } else {
+                        zerompk::from_msgpack(&base).unwrap_or_default()
+                    }
+                } else {
+                    Vec::new()
+                };
+            scan_filters.push(expr_filter_qualified(cond));
+            zerompk::to_msgpack_vec(&scan_filters).map_err(|e| crate::Error::Serialization {
+                format: "msgpack".into(),
+                detail: format!("join filter serialization: {e}"),
+            })
+        }
+    }
+}
+
 pub(super) fn convert_scan(p: ScanParams<'_>) -> crate::Result<Vec<PhysicalTask>> {
     let ScanParams {
         collection,
@@ -153,6 +187,7 @@ pub(super) fn convert_join(p: JoinPlanParams<'_>) -> crate::Result<Vec<PhysicalTask>> {
 pub(super) fn convert_recursive_scan(
-    collection: &str,
-    base_filters: &[Filter],
-    recursive_filters: &[Filter],
-    max_iterations: &usize,
-    distinct: &bool,
-    limit: &usize,
-    tenant_id: TenantId,
+    p: RecursiveScanParams<'_>,
 ) -> crate::Result<Vec<PhysicalTask>> {
-    let vshard = VShardId::from_collection(collection);
+    let vshard = VShardId::from_collection(p.collection);
     Ok(vec![PhysicalTask {
-        tenant_id,
+        tenant_id: p.tenant_id,
         vshard_id: vshard,
         plan: PhysicalPlan::Query(QueryOp::RecursiveScan {
-            collection: collection.into(),
-            base_filters: serialize_filters(base_filters)?,
-            recursive_filters: serialize_filters(recursive_filters)?,
-            max_iterations: *max_iterations,
-            distinct: *distinct,
-            limit: *limit,
+            collection: p.collection.into(),
+            base_filters: serialize_filters(p.base_filters)?,
+            recursive_filters: serialize_filters(p.recursive_filters)?,
+            join_link: p.join_link.clone(),
+            max_iterations: *p.max_iterations,
+            distinct: *p.distinct,
+            limit: *p.limit,
         }),
         post_set_op: PostSetOp::None,
     }])
 }
diff --git a/nodedb/src/control/planner/sql_plan_convert/scan_params.rs b/nodedb/src/control/planner/sql_plan_convert/scan_params.rs
index adc1ad16..1ad4adea 100644
--- a/nodedb/src/control/planner/sql_plan_convert/scan_params.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/scan_params.rs
@@ -26,6 +26,7 @@ pub(super) struct JoinPlanParams<'a> {
     pub right: &'a SqlPlan,
     pub on: &'a [(String, String)],
     pub join_type: &'a nodedb_sql::types::JoinType,
+    pub condition: &'a Option<SqlExpr>,
     pub limit: &'a usize,
     pub projection: &'a [Projection],
     pub filters: &'a [Filter],
@@ -33,6 +34,18 @@ pub(super) struct JoinPlanParams<'a> {
     pub ctx: &'a ConvertContext,
 }
 
+/// Parameters for `convert_recursive_scan`.
+pub(super) struct RecursiveScanParams<'a> {
+    pub collection: &'a str,
+    pub base_filters: &'a [Filter],
+    pub recursive_filters: &'a [Filter],
+    pub join_link: &'a Option<(String, String)>,
+    pub max_iterations: &'a usize,
+    pub distinct: &'a bool,
+    pub limit: &'a usize,
+    pub tenant_id: TenantId,
+}
+
 /// Parameters for `convert_timeseries_scan`.
 pub(super) struct TimeseriesScanParams<'a> {
     pub collection: &'a str,
diff --git a/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs b/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
index 09958e90..57d3939f 100644
--- a/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/value/defaults.rs
@@ -67,5 +67,30 @@ fn parse_parametric_or_literal(expr: &str, upper: &str) -> Option<nodedb_types::Value> {
+
+fn parse_via_expr_parser(expr: &str) -> Option<nodedb_types::Value> {
+    let sql_expr = nodedb_sql::parse_expr_string(expr).ok()?;
+    let folded = nodedb_sql::planner::const_fold::fold_constant_default(&sql_expr)?;
+    Some(sql_value_to_ndb(folded))
+}
+
+fn sql_value_to_ndb(v: nodedb_sql::types::SqlValue) -> nodedb_types::Value {
+    use nodedb_sql::types::SqlValue;
+    match v {
+        SqlValue::Null => nodedb_types::Value::Null,
+        SqlValue::Bool(b) => nodedb_types::Value::Bool(b),
+        SqlValue::Int(i) => nodedb_types::Value::Integer(i),
+        SqlValue::Float(f) => nodedb_types::Value::Float(f),
+        SqlValue::String(s) => nodedb_types::Value::String(s),
+        SqlValue::Bytes(b) => nodedb_types::Value::Bytes(b),
+        SqlValue::Array(a) => {
+            nodedb_types::Value::Array(a.into_iter().map(sql_value_to_ndb).collect())
+        }
+    }
+}
diff --git a/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs b/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
index e63a8739..309ded76 100644
--- a/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
+++ b/nodedb/src/control/planner/sql_plan_convert/value/time_range.rs
@@ -23,15 +23,15 @@ fn extract_time_bounds_from_filter(expr: &FilterExpr, min_ts: &mut i64, max_ts: &mut i64) {
         FilterExpr::Comparison { field, op, value } if is_time_field(field) => {
             if let Some(ms) = sql_value_to_timestamp_ms(value) {
                 match op {
-                    nodedb_sql::types::CompareOp::Ge | nodedb_sql::types::CompareOp::Gt => {
-                        if ms > *min_ts {
-                            *min_ts = ms;
-                        }
+                    nodedb_sql::types::CompareOp::Ge | nodedb_sql::types::CompareOp::Gt
+                        if ms > *min_ts =>
+                    {
+                        *min_ts = ms;
                     }
-                    nodedb_sql::types::CompareOp::Le | nodedb_sql::types::CompareOp::Lt => {
-                        if ms < *max_ts {
-                            *max_ts = ms;
-                        }
+                    nodedb_sql::types::CompareOp::Le | nodedb_sql::types::CompareOp::Lt
+                        if ms < *max_ts =>
+                    {
+                        *max_ts = ms;
                     }
                     nodedb_sql::types::CompareOp::Eq => {
                         *min_ts = ms;
diff --git a/nodedb/src/control/security/catalog/checkpoints.rs b/nodedb/src/control/security/catalog/checkpoints.rs
index 0c295378..f6d915af 100644
--- a/nodedb/src/control/security/catalog/checkpoints.rs
+++ b/nodedb/src/control/security/catalog/checkpoints.rs
@@ -110,7 +110,7 @@ impl SystemCatalog {
         }
 
         // Sort by created_at descending (most recent first).
-        records.sort_by(|a, b| b.created_at.cmp(&a.created_at));
+        records.sort_by_key(|r| std::cmp::Reverse(r.created_at));
 
         if records.len() > limit && limit > 0 {
             records.truncate(limit);
diff --git a/nodedb/src/control/security/explain.rs b/nodedb/src/control/security/explain.rs
index 24ff8d0d..a5ccb90d 100644
--- a/nodedb/src/control/security/explain.rs
+++ b/nodedb/src/control/security/explain.rs
@@ -166,13 +166,12 @@ pub fn lint_predicate(predicate: &super::predicate::RlsPredicate) -> Vec<String> {
         super::predicate::RlsPredicate::AlwaysFalse => {
             warnings.push("contradiction: predicate is always false (blocks everything)".into());
         }
-        super::predicate::RlsPredicate::Compare { value, .. } => {
-            if !value.is_auth_ref() && matches!(value, super::predicate::PredicateValue::Literal(_))
-            {
-                warnings.push(
-                    "static predicate: no $auth reference — same result for all users".into(),
-                );
-            }
+        super::predicate::RlsPredicate::Compare { value, .. }
+            if !value.is_auth_ref()
+                && matches!(value, super::predicate::PredicateValue::Literal(_)) =>
+        {
+            warnings
+                .push("static predicate: no $auth reference — same result for all users".into());
         }
         super::predicate::RlsPredicate::And(children)
         | super::predicate::RlsPredicate::Or(children) => {
diff --git a/nodedb/src/control/server/native/dispatch/plan_builder/query.rs b/nodedb/src/control/server/native/dispatch/plan_builder/query.rs
index 00259c83..9f767c59 100644
--- a/nodedb/src/control/server/native/dispatch/plan_builder/query.rs
+++ b/nodedb/src/control/server/native/dispatch/plan_builder/query.rs
@@ -16,6 +16,7 @@ pub(crate) fn build_recursive_scan(
         collection: collection.to_string(),
         base_filters,
         recursive_filters: Vec::new(),
+        join_link: None,
         max_iterations: 100,
         distinct: true,
         limit,
diff --git a/nodedb/src/control/server/native/session.rs b/nodedb/src/control/server/native/session.rs
index d143859d..105945eb 100644
--- a/nodedb/src/control/server/native/session.rs
+++ b/nodedb/src/control/server/native/session.rs
@@ -418,17 +418,12 @@ fn chunk_large_response(
     };
     let sample_bytes = codec::encode_response(&sample_resp, format)?;
     let sample_count = total_rows.min(100);
-    let per_row_estimate = if sample_count > 0 {
-        sample_bytes.len() / sample_count
-    } else {
-        256 // fallback
-    };
+    let per_row_estimate = sample_bytes.len().checked_div(sample_count).unwrap_or(256);
 
-    let rows_per_chunk = if per_row_estimate > 0 {
-        (target_size / per_row_estimate).max(1)
-    } else {
-        1000
-    };
+    let rows_per_chunk = target_size
+        .checked_div(per_row_estimate)
+        .map(|v| v.max(1))
+        .unwrap_or(1000);
 
     let mut frames = Vec::new();
     let chunks: Vec<_> = rows.chunks(rows_per_chunk).collect();
diff --git a/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs b/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs
index da6d37e5..6c037a87 100644
--- a/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs
+++ b/nodedb/src/control/server/pgwire/ddl/collection/check_constraint.rs
@@ -271,24 +271,24 @@ fn restructure_subquery_check(expr: &str) -> RestructuredCheck {
 /// Converts `NEW.amount > 0` → `amount > 0` so the expression can be parsed
 /// as bare column references by `parse_generated_expr`.
 fn strip_new_prefix(sql: &str) -> String {
+    let chars: Vec<char> = sql.chars().collect();
     let mut result = String::with_capacity(sql.len());
-    let upper = sql.to_uppercase();
-    let bytes = sql.as_bytes();
     let mut i = 0;
-    while i < bytes.len() {
-        if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") {
-            // Check word boundary before.
-            if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') {
-                result.push(bytes[i] as char);
-                i += 1;
+    while i < chars.len() {
+        if i + 4 <= chars.len() {
+            let window: String = chars[i..i + 4].iter().collect();
+            if window.eq_ignore_ascii_case("NEW.") {
+                if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') {
+                    result.push(chars[i]);
+                    i += 1;
+                    continue;
+                }
+                i += 4;
                 continue;
             }
-            // Skip "NEW." (4 chars).
-            i += 4;
-            continue;
         }
-        result.push(bytes[i] as char);
+        result.push(chars[i]);
         i += 1;
     }
     result
@@ -330,33 +330,35 @@ fn substitute_new_refs(sql: &str, fields: &HashMap<String, String>) -> String {
 
 /// Replace any remaining `NEW.xxx` references (not matched by known fields) with NULL.
 fn replace_remaining_new_refs(text: &str) -> String {
-    let upper = text.to_uppercase();
+    let chars: Vec<char> = text.chars().collect();
     let mut result = String::with_capacity(text.len());
     let mut i = 0;
-    let bytes = text.as_bytes();
-    while i < bytes.len() {
+    while i < chars.len() {
         // Check for "NEW." prefix (case insensitive).
-        if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") {
-            // Check word boundary before.
-            if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') {
-                result.push(bytes[i] as char);
-                i += 1;
-                continue;
-            }
-            // Find the end of the identifier after "NEW.".
-            let start = i + 4;
-            let mut end = start;
-            while end < bytes.len() && (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_') {
-                end += 1;
-            }
-            if end > start {
-                result.push_str("NULL");
-                i = end;
-                continue;
+        if i + 4 <= chars.len() {
+            let window: String = chars[i..i + 4].iter().collect();
+            if window.eq_ignore_ascii_case("NEW.") {
+                if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') {
+                    result.push(chars[i]);
+                    i += 1;
+                    continue;
+                }
+                // Find the end of the identifier after "NEW.".
+                let start = i + 4;
+                let mut end = start;
+                while end < chars.len() && (chars[end].is_ascii_alphanumeric() || chars[end] == '_')
+                {
+                    end += 1;
+                }
+                if end > start {
+                    result.push_str("NULL");
+                    i = end;
+                    continue;
+                }
             }
         }
-        result.push(bytes[i] as char);
+        result.push(chars[i]);
         i += 1;
     }
     result
diff --git a/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs b/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs
index 0d20ccb6..3f99fb86 100644
--- a/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs
+++ b/nodedb/src/control/server/pgwire/ddl/constraint/validate.rs
@@ -29,21 +29,23 @@ pub(super) fn validate_subquery_pattern(check_sql: &str) -> PgWireResult<()> {
 
 /// Strip `NEW.` prefix for validation parsing.
 pub(super) fn strip_new_prefix_for_validation(sql: &str) -> String {
-    let upper = sql.to_uppercase();
-    let bytes = sql.as_bytes();
+    let chars: Vec<char> = sql.chars().collect();
     let mut result = String::with_capacity(sql.len());
     let mut i = 0;
-    while i < bytes.len() {
-        if i + 4 <= bytes.len() && upper[i..].starts_with("NEW.") {
-            if i > 0 && (bytes[i - 1].is_ascii_alphanumeric() || bytes[i - 1] == b'_') {
-                result.push(bytes[i] as char);
-                i += 1;
+    while i < chars.len() {
+        if i + 4 <= chars.len() {
+            let window: String = chars[i..i + 4].iter().collect();
+            if window.eq_ignore_ascii_case("NEW.") {
+                if i > 0 && (chars[i - 1].is_ascii_alphanumeric() || chars[i - 1] == '_') {
+                    result.push(chars[i]);
+                    i += 1;
+                    continue;
+                }
+                i += 4;
                 continue;
             }
-            i += 4;
-            continue;
         }
-        result.push(bytes[i] as char);
+        result.push(chars[i]);
         i += 1;
     }
     result
diff --git a/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs b/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs
index 395f9652..9b832c14 100644
--- a/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs
+++ b/nodedb/src/control/server/pgwire/ddl/materialized_view/parse.rs
@@ -50,16 +50,11 @@ pub fn parse_create_mv(sql: &str) -> PgWireResult<(String, String, String, String)> {
         .ok_or_else(|| sqlstate_error("42601", "expected AS SELECT ... clause"))?;
     let query_start = after_on_start + as_pos + KW_AS.len();
 
-    // Find end of query: WITH clause or end of string.
-    let remaining = &upper[query_start..];
-    let with_pos = remaining.find(" WITH").or_else(|| {
-        if remaining.trim_start().starts_with("WITH") {
-            Some(0)
-        } else {
-            None
-        }
-    });
-    let query_end = with_pos.map(|p| query_start + p).unwrap_or(sql.len());
+    // Find end of query: a trailing `WITH (...)` options clause, or end of string.
+    // We must not match a CTE `WITH cte AS (...)` inside the SELECT body.
+    // The options clause is always `WITH (` — a CTE `WITH` is followed by
+    // an identifier, never `(`.
+    let query_end = find_trailing_with_options(&upper, query_start).unwrap_or(sql.len());
 
     let query_sql = sql[query_start..query_end].trim().to_string();
     if query_sql.is_empty() {
@@ -71,6 +66,35 @@ pub fn parse_create_mv(sql: &str) -> PgWireResult<(String, String, String, String)> {
     Ok((name, source, query_sql, refresh_mode))
 }
 
+/// Find a trailing `WITH (...)` options clause that is NOT a CTE.
+///
+/// Scans backward from the end of the string for `WITH` followed (after
+/// optional whitespace) by `(`. CTE syntax is `WITH cte AS (...)` — the
+/// word after `WITH` is an identifier, never `(`.
+fn find_trailing_with_options(upper: &str, query_start: usize) -> Option<usize> {
+    let region = &upper[query_start..];
+    // Search backward for the last " WITH" or "WITH" that is followed by "(".
+    let mut search_end = region.len();
+    loop {
+        let pos = region[..search_end].rfind("WITH")?;
+        // Verify word boundary before WITH.
+        if pos > 0 {
+            let before = region.as_bytes()[pos - 1];
+            if before.is_ascii_alphanumeric() || before == b'_' {
+                search_end = pos;
+                continue;
+            }
+        }
+        // Check that WITH is followed by `(` (possibly with whitespace).
+        let after_with = region[pos + 4..].trim_start();
+        if after_with.starts_with('(') {
+            return Some(query_start + pos);
+        }
+        // This WITH is a CTE or identifier — keep searching backward.
+        search_end = pos;
+    }
+}
+
 /// Extract refresh mode from WITH clause.
 fn extract_refresh_mode(upper: &str, sql: &str) -> String {
     let with_pos = match upper.rfind("WITH") {
diff --git a/nodedb/src/control/server/pgwire/ddl/rls/parse.rs b/nodedb/src/control/server/pgwire/ddl/rls/parse.rs
index ac7bcbfc..00aad0e2 100644
--- a/nodedb/src/control/server/pgwire/ddl/rls/parse.rs
+++ b/nodedb/src/control/server/pgwire/ddl/rls/parse.rs
@@ -70,10 +70,7 @@ pub fn parse_create_rls_policy(
         .map(|i| using_idx + 1 + i)
         .unwrap_or(parts.len());
 
-    let predicate_str = parts[using_idx + 1..pred_end]
-        .join(" ")
-        .trim_matches(|c: char| c == '(' || c == ')')
-        .to_string();
+    let predicate_str = strip_outer_parens(&parts[using_idx + 1..pred_end].join(" "));
 
     let is_restrictive = parts[pred_end..]
         .iter()
@@ -114,7 +111,7 @@ pub fn parse_create_rls_policy(
         let field = pred_parts[0];
         let op = pred_parts[1];
-        let value_str = pred_parts[2..].join(" ").trim_matches('\'').to_string();
+        let value_str = strip_single_quotes(&pred_parts[2..].join(" "));
 
         let filter = crate::bridge::scan_filter::ScanFilter {
             field: field.to_string(),
@@ -161,3 +158,48 @@ pub fn parse_create_rls_policy(
         on_deny,
     })
 }
+
+/// Strip at most one matched pair of outer single quotes.
+///
+/// `"'hello'"` → `"hello"`, `"no_quotes"` → `"no_quotes"`.
+fn strip_single_quotes(s: &str) -> String {
+    let trimmed = s.trim();
+    if trimmed.len() >= 2 && trimmed.starts_with('\'') && trimmed.ends_with('\'') {
+        trimmed[1..trimmed.len() - 1].to_string()
+    } else {
+        trimmed.to_string()
+    }
+}
+
+/// Strip at most one matched pair of outer parentheses. Unlike
+/// `trim_matches('(' | ')')`, this preserves balanced inner parens.
+///
+/// `"((x > 0) AND (y = 1))"` → `"(x > 0) AND (y = 1)"`
+/// `"(x > 0)"` → `"x > 0"`
+/// `"x > 0"` → `"x > 0"` (no outer parens)
+fn strip_outer_parens(s: &str) -> String {
+    let trimmed = s.trim();
+    if trimmed.starts_with('(') && trimmed.ends_with(')') {
+        // Verify the outer parens are actually matched (not two separate groups).
+        let inner = &trimmed[1..trimmed.len() - 1];
+        let mut depth = 0i32;
+        let mut balanced = true;
+        for ch in inner.chars() {
+            match ch {
+                '(' => depth += 1,
+                ')' => {
+                    depth -= 1;
+                    if depth < 0 {
+                        balanced = false;
+                        break;
+                    }
+                }
+                _ => {}
+            }
+        }
+        if balanced && depth == 0 {
+            return inner.trim().to_string();
+        }
+    }
+    trimmed.to_string()
+}
diff --git a/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs b/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
index d70e829a..5797cc16 100644
--- a/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
+++ b/nodedb/src/control/server/pgwire/ddl/trigger/parse.rs
@@ -320,21 +320,42 @@ fn extract_dollar_quoted_body(s: &str) -> Option<(&str, String)> {
 
 fn find_begin_pos(s: &str) -> Option<usize> {
-    let upper = s.to_uppercase();
-    let mut search_from = 0;
-    loop {
-        let pos = upper[search_from..].find("BEGIN")?;
-        let abs_pos = search_from + pos;
-        let before_ok = abs_pos == 0
-            || !s.as_bytes()[abs_pos - 1].is_ascii_alphanumeric()
-                && s.as_bytes()[abs_pos - 1] != b'_';
-        let after_pos = abs_pos + 5;
-        let after_ok = after_pos >= s.len()
-            || !s.as_bytes()[after_pos].is_ascii_alphanumeric() && s.as_bytes()[after_pos] != b'_';
-        if before_ok && after_ok {
-            return Some(abs_pos);
+    let bytes = s.as_bytes();
+    let mut i = 0;
+
+    while i < bytes.len() {
+        // Skip single-quoted string literals so that 'BEGIN' inside a
+        // string (e.g. in a WHEN clause) is not mistaken for the body start.
+        if bytes[i] == b'\'' {
+            i += 1;
+            while i < bytes.len() {
+                if bytes[i] == b'\'' {
+                    i += 1;
+                    // Handle '' escape.
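+                    // e.g. a hypothetical `WHEN (note = 'it''s begun')` keeps
+                    // scanning inside the literal instead of ending it early.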
+ if i < bytes.len() && bytes[i] == b'\'' { + i += 1; + continue; + } + break; + } + i += 1; + } + continue; + } + + // Check for "BEGIN" keyword at this position. + if i + 5 <= bytes.len() && &upper[i..i + 5] == "BEGIN" { + let before_ok = i == 0 || !bytes[i - 1].is_ascii_alphanumeric() && bytes[i - 1] != b'_'; + let after_pos = i + 5; + let after_ok = after_pos >= bytes.len() + || !bytes[after_pos].is_ascii_alphanumeric() && bytes[after_pos] != b'_'; + if before_ok && after_ok { + return Some(i); + } } - search_from = abs_pos + 5; + i += 1; } + + None } #[cfg(test)] diff --git a/nodedb/src/control/server/pgwire/handler/plan.rs b/nodedb/src/control/server/pgwire/handler/plan.rs index 955e9567..8d4c8e47 100644 --- a/nodedb/src/control/server/pgwire/handler/plan.rs +++ b/nodedb/src/control/server/pgwire/handler/plan.rs @@ -124,6 +124,7 @@ pub(super) fn describe_plan(plan: &PhysicalPlan) -> PlanKind { | PhysicalPlan::Query(QueryOp::FacetCounts { .. }) | PhysicalPlan::Query(QueryOp::HashJoin { .. }) | PhysicalPlan::Query(QueryOp::InlineHashJoin { .. }) + | PhysicalPlan::Query(QueryOp::RecursiveScan { .. }) | PhysicalPlan::Graph(GraphOp::Algo { .. }) | PhysicalPlan::Graph(GraphOp::Match { .. }) => PlanKind::MultiRow, diff --git a/nodedb/src/control/server/resp/handler.rs b/nodedb/src/control/server/resp/handler.rs index 121e19cb..25d7fe9b 100644 --- a/nodedb/src/control/server/resp/handler.rs +++ b/nodedb/src/control/server/resp/handler.rs @@ -259,27 +259,23 @@ async fn handle_scan(cmd: &RespCommand, session: &RespSession, state: &SharedSta i += 2; } // NodeDB extension: SCAN 0 FILTER <field> = <value> - Some(ref flag) if flag == "FILTER" => { + Some(ref flag) if flag == "FILTER" && i + 4 <= cmd.argc() => { // Parse simple "field = value" predicate (needs 4 args: FILTER field = value).
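+                // Illustrative mapping (assumed client syntax):
+                //   SCAN 0 FILTER status = active
+                // builds the JSON filter [{"field":"status","op":"eq","value":"active"}]
+                // below, which is then msgpack-encoded for the executor.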
- if i + 4 <= cmd.argc() { - let field = cmd.arg_str(i + 1).unwrap_or(""); - let _op = cmd.arg_str(i + 2).unwrap_or(""); // "=" expected - let value = cmd.arg_str(i + 3).unwrap_or(""); - let scan_filter = serde_json::json!([{ - "field": field, - "op": "eq", - "value": value, - }]); - match nodedb_types::json_to_msgpack(&scan_filter) { - Ok(bytes) => filter_bytes = bytes, - Err(_) => { - return RespValue::err("ERR filter serialization failed"); - } + let field = cmd.arg_str(i + 1).unwrap_or(""); + let _op = cmd.arg_str(i + 2).unwrap_or(""); // "=" expected + let value = cmd.arg_str(i + 3).unwrap_or(""); + let scan_filter = serde_json::json!([{ + "field": field, + "op": "eq", + "value": value, + }]); + match nodedb_types::json_to_msgpack(&scan_filter) { + Ok(bytes) => filter_bytes = bytes, + Err(_) => { + return RespValue::err("ERR filter serialization failed"); } - i += 4; - } else { - i += 1; } + i += 4; } _ => { i += 1; diff --git a/nodedb/src/control/server/sync/listener.rs b/nodedb/src/control/server/sync/listener.rs index a974844b..daa451ab 100644 --- a/nodedb/src/control/server/sync/listener.rs +++ b/nodedb/src/control/server/sync/listener.rs @@ -316,9 +316,9 @@ async fn handle_sync_session( } } Ok(Message::Ping(data)) => { - if ws.send(Message::Pong(data)).await.is_err() { + let Ok(_) = ws.send(Message::Pong(data)).await else { break; - } + }; } Ok(Message::Close(_)) => break, Err(e) => { diff --git a/nodedb/src/data/executor/dispatch/other.rs b/nodedb/src/data/executor/dispatch/other.rs index 1e1ca617..5adc2c77 100644 --- a/nodedb/src/data/executor/dispatch/other.rs +++ b/nodedb/src/data/executor/dispatch/other.rs @@ -139,6 +139,7 @@ impl CoreLoop { collection, base_filters, recursive_filters, + join_link, max_iterations, distinct, limit, @@ -148,6 +149,7 @@ impl CoreLoop { collection, base_filters, recursive_filters, + join_link.as_ref(), *max_iterations, *distinct, *limit, diff --git a/nodedb/src/data/executor/handlers/document/text_extract.rs b/nodedb/src/data/executor/handlers/document/text_extract.rs index 9a4ba722..2646771b 100644 --- a/nodedb/src/data/executor/handlers/document/text_extract.rs +++ b/nodedb/src/data/executor/handlers/document/text_extract.rs @@ -14,10 +14,8 @@ pub fn extract_indexable_text(doc: &serde_json::Value) -> String { pub(super) fn collect_text(val: &serde_json::Value, parts: &mut Vec<String>) { match val { - serde_json::Value::String(s) => { - if !s.is_empty() { - parts.push(s.clone()); - } + serde_json::Value::String(s) if !s.is_empty() => { + parts.push(s.clone()); } serde_json::Value::Object(map) => { for v in map.values() { diff --git a/nodedb/src/data/executor/handlers/facet.rs b/nodedb/src/data/executor/handlers/facet.rs index d5a09d78..c5ef5eae 100644 --- a/nodedb/src/data/executor/handlers/facet.rs +++ b/nodedb/src/data/executor/handlers/facet.rs @@ -151,7 +151,7 @@ impl CoreLoop { } let mut result: Vec<(String, usize)> = counts.into_iter().collect(); - result.sort_by(|a, b| b.1.cmp(&a.1)); // Count descending. + result.sort_by_key(|r| std::cmp::Reverse(r.1)); // Count descending. result } } diff --git a/nodedb/src/data/executor/handlers/recursive.rs b/nodedb/src/data/executor/handlers/recursive.rs index 6fc48579..6c2fd1b9 100644 --- a/nodedb/src/data/executor/handlers/recursive.rs +++ b/nodedb/src/data/executor/handlers/recursive.rs @@ -1,8 +1,9 @@ //! Recursive CTE handler: iterative fixed-point execution. //! //! Executes the base query once to seed the working table, then -//! repeatedly executes the recursive query until no new rows are -//!
produced (fixed point) or max_iterations is reached. +//! repeatedly joins the collection against the working table via +//! the `join_link` until no new rows are produced (fixed point) +//! or `max_iterations` is reached. use std::collections::HashSet; @@ -16,10 +17,11 @@ impl CoreLoop { /// /// Algorithm: /// 1. Seed: scan collection with base_filters → working_table - /// 2. Loop: scan collection with recursive_filters, join against working_table - /// 3. Add new rows to working_table - /// 4. Repeat until no new rows or max_iterations reached - /// 5. Return all accumulated rows + /// 2. Loop: for each row in collection, check if `join_link.0` value + /// matches any `join_link.1` value in the working_table → new matches + /// 3. New matches become the working table for the next iteration + /// 4. Accumulate all results + /// 5. Repeat until no new rows or max_iterations reached #[allow(clippy::too_many_arguments)] pub(in crate::data::executor) fn execute_recursive_scan( &mut self, @@ -28,6 +30,7 @@ impl CoreLoop { collection: &str, base_filters: &[u8], recursive_filters: &[u8], + join_link: Option<&(String, String)>, max_iterations: usize, distinct: bool, limit: usize, @@ -66,6 +69,17 @@ impl CoreLoop { } }; + // Check if the collection uses strict (Binary Tuple) encoding. + let config_key = format!("{tid}:{collection}"); + let strict_schema = self.doc_configs.get(&config_key).and_then(|c| { + if let crate::bridge::physical_plan::StorageMode::Strict { ref schema } = c.storage_mode + { + Some(schema.clone()) + } else { + None + } + }); + // Scan all documents once (used for both base and recursive steps). let all_docs = match self.sparse.scan_documents(tid, collection, scan_limit) { Ok(d) => d, @@ -79,15 +93,40 @@ impl CoreLoop { } }; - // Step 1: Seed working table with base query results (raw msgpack). + // Convert raw bytes to msgpack. For strict docs, this requires the schema. + let to_msgpack = |value: &[u8]| -> Option<Vec<u8>> { + if let Some(ref schema) = strict_schema { + super::super::strict_format::binary_tuple_to_msgpack(value, schema) + } else { + Some(super::super::doc_format::json_to_msgpack(value)) + } + }; + + // Step 1: Seed working table with base query results. let mut results: Vec<Vec<u8>> = Vec::new(); let mut seen_keys: HashSet<String> = HashSet::new(); + tracing::debug!( + core = self.core_id, + %collection, + all_docs = all_docs.len(), + base_preds = base_preds.len(), + strict = strict_schema.is_some(), + ?join_link, + "recursive CTE: starting seed" + ); + for (_doc_id, value) in &all_docs { - if !base_preds.iter().all(|f| f.matches_binary(value)) { + let mp = match to_msgpack(value) { + Some(m) => m, + None => { + tracing::debug!(core = self.core_id, "to_msgpack returned None"); + continue; + } + }; + if !base_preds.iter().all(|f| f.matches_binary(&mp)) { continue; } - let mp = super::super::doc_format::json_to_msgpack(value); let key = if distinct { nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() } else { @@ -98,55 +137,135 @@ impl CoreLoop { } } + tracing::debug!( + core = self.core_id, + seed_count = results.len(), + "recursive CTE: seed complete" + ); + // Step 2: Iterate recursive step until fixed point. - let mut prev_count = 0; - for iteration in 0..max_iterations { - if results.len() >= limit || results.len() == prev_count { - break; - } - prev_count = results.len(); - - // The "working table" for this iteration is the rows added in the - // previous iteration.
For the recursive step, we scan the collection - // and filter by recursive predicates AND check that the row relates - // to existing working table rows (by matching any field). - // - // In a full SQL recursive CTE, the recursive term references the - // CTE name. Here, we approximate by applying recursive filters to - // the full collection and adding any new matching rows. - let mut new_rows = Vec::new(); - for (doc_id, value) in &all_docs { - if results.len() + new_rows.len() >= limit { + if let Some((collection_field, working_field)) = join_link { + // Working-table hash-join: each iteration finds collection rows + // where `collection_field` matches a `working_field` value from + // the previous iteration's new rows. + let mut frontier = results.clone(); + + for iteration in 0..max_iterations { + if results.len() >= limit || frontier.is_empty() { break; } - if !recursive_preds.iter().all(|f| f.matches_binary(value)) { - continue; + + // Build hash set of working_field values from the frontier. + let frontier_values: HashSet<String> = frontier + .iter() + .filter_map(|row| extract_field_string(row, working_field)) + .collect(); + + if frontier_values.is_empty() { + break; } - let mp = super::super::doc_format::json_to_msgpack(value); - let key = if distinct { - nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() - } else { - doc_id.clone() - }; - if !distinct || seen_keys.insert(key) { - new_rows.push(mp); + + let mut new_rows = Vec::new(); + for (_doc_id, value) in &all_docs { + if results.len() + new_rows.len() >= limit { + break; + } + + let mp = match to_msgpack(value) { + Some(m) => m, + None => continue, + }; + + // Apply recursive filters (WHERE clause from recursive branch). + if !recursive_preds.iter().all(|f| f.matches_binary(&mp)) { + continue; + } + + // Check join link: collection_field value must be in frontier. + let field_val = match extract_field_string(&mp, collection_field) { + Some(v) => v, + None => continue, + }; + if !frontier_values.contains(&field_val) { + continue; + } + + let key = if distinct { + nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() + } else { + String::new() + }; + if !distinct || seen_keys.insert(key) { + new_rows.push(mp); + } + } + + if new_rows.is_empty() { + break; } - } - if new_rows.is_empty() { - break; + tracing::debug!( + core = self.core_id, + iteration, + new_rows = new_rows.len(), + total = results.len() + new_rows.len(), + "recursive CTE iteration (join-link)" + ); + frontier = new_rows.clone(); + results.extend(new_rows); } + } else { + // No join link — fall back to filter-only iteration (original behavior).
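+            // Convergence sketch (behavioral note, not a spec): with `distinct`,
+            // `seen_keys` means a repeat pass over the same rows adds nothing, so
+            // the fixed point arrives after one extra scan; without `distinct`,
+            // growth is bounded only by `limit` and `max_iterations`.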
+ let mut prev_count = 0; + for iteration in 0..max_iterations { + if results.len() >= limit || results.len() == prev_count { + break; + } + prev_count = results.len(); + + let mut new_rows = Vec::new(); + for (doc_id, value) in &all_docs { + if results.len() + new_rows.len() >= limit { + break; + } + let mp = match to_msgpack(value) { + Some(m) => m, + None => continue, + }; + if !recursive_preds.iter().all(|f| f.matches_binary(&mp)) { + continue; + } + let key = if distinct { + nodedb_types::msgpack_to_json_string(&mp).unwrap_or_default() + } else { + doc_id.clone() + }; + if !distinct || seen_keys.insert(key) { + new_rows.push(mp); + } + } - tracing::debug!( - core = self.core_id, - iteration, - new_rows = new_rows.len(), - total = results.len() + new_rows.len(), - "recursive CTE iteration" - ); - results.extend(new_rows); + if new_rows.is_empty() { + break; + } + + tracing::debug!( + core = self.core_id, + iteration, + new_rows = new_rows.len(), + total = results.len() + new_rows.len(), + "recursive CTE iteration (filter-only)" + ); + results.extend(new_rows); + } } + tracing::debug!( + core = self.core_id, + total = results.len(), + "recursive CTE: iteration complete" + ); + // Truncate to limit. results.truncate(limit); @@ -156,14 +275,24 @@ impl CoreLoop { for row in &results { payload.extend_from_slice(row); } - match Ok::<Vec<u8>, crate::Error>(payload) { - Ok(payload) => self.response_with_payload(task, payload), - Err(e) => self.response_error( - task, - ErrorCode::Internal { - detail: e.to_string(), - }, - ), + self.response_with_payload(task, payload) + } +} + +/// Extract a field value from a msgpack document as a string for hash lookup. +fn extract_field_string(msgpack_doc: &[u8], field_name: &str) -> Option<String> { + let value = nodedb_types::value_from_msgpack(msgpack_doc).ok()?; + match &value { + nodedb_types::Value::Object(map) => { + let v = map.get(field_name)?; + match v { + nodedb_types::Value::String(s) => Some(s.clone()), + nodedb_types::Value::Integer(i) => Some(i.to_string()), + nodedb_types::Value::Float(f) => Some(f.to_string()), + nodedb_types::Value::Bool(b) => Some(b.to_string()), + _ => Some(format!("{v:?}")), + } } + _ => None, } } diff --git a/nodedb/src/engine/graph/algo/kcore.rs b/nodedb/src/engine/graph/algo/kcore.rs index 8c67420f..07d1a79c 100644 --- a/nodedb/src/engine/graph/algo/kcore.rs +++ b/nodedb/src/engine/graph/algo/kcore.rs @@ -91,7 +91,7 @@ pub fn run(csr: &CsrIndex) -> AlgoResultBatch { // Build result sorted by coreness descending. let mut scored: Vec<(usize, usize)> = coreness.into_iter().enumerate().collect(); - scored.sort_by(|a, b| b.1.cmp(&a.1)); + scored.sort_by_key(|&(_, k)| std::cmp::Reverse(k)); let mut batch = AlgoResultBatch::new(GraphAlgorithm::KCore); for (node, k) in scored { diff --git a/nodedb/src/engine/sparse/btree_scan.rs b/nodedb/src/engine/sparse/btree_scan.rs index ad48316f..ea870dbd 100644 --- a/nodedb/src/engine/sparse/btree_scan.rs +++ b/nodedb/src/engine/sparse/btree_scan.rs @@ -190,7 +190,7 @@ impl SparseEngine { } let mut result: Vec<(String, usize)> = groups.into_iter().collect(); - result.sort_by(|a, b| b.1.cmp(&a.1)); // Sort by count descending (most popular first). + result.sort_by_key(|r| std::cmp::Reverse(r.1)); // Sort by count descending (most popular first).
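+        // (std::cmp::Reverse inverts the count's Ord, so the ascending
+        // sort_by_key above is count-descending with no manual comparator.)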
debug!( collection, field, diff --git a/nodedb/src/engine/timeseries/compound_filter.rs b/nodedb/src/engine/timeseries/compound_filter.rs index c22f284a..379a84a1 100644 --- a/nodedb/src/engine/timeseries/compound_filter.rs +++ b/nodedb/src/engine/timeseries/compound_filter.rs @@ -93,7 +93,7 @@ impl CompoundTagIndex { .iter() .map(|(k, &v)| (k.clone(), v)) .collect(); - combos.sort_by(|a, b| b.1.cmp(&a.1)); + combos.sort_by_key(|c| std::cmp::Reverse(c.1)); combos.truncate(n); combos } diff --git a/nodedb/tests/procedure_execution.rs b/nodedb/tests/procedure_execution.rs index 5c943d8d..faae57e3 100644 --- a/nodedb/tests/procedure_execution.rs +++ b/nodedb/tests/procedure_execution.rs @@ -390,10 +390,10 @@ fn execution_budget_exhaustion() { } #[test] -fn execution_budget_unlimited() { +fn execution_budget_trigger_default() { use nodedb::control::planner::procedural::executor::fuel::ExecutionBudget; - let mut budget = ExecutionBudget::unlimited(); + let mut budget = ExecutionBudget::trigger_default(); for _ in 0..1000 { assert!(budget.consume_iteration().is_ok()); } diff --git a/nodedb/tests/sql_arithmetic_overflow.rs b/nodedb/tests/sql_arithmetic_overflow.rs new file mode 100644 index 00000000..d28685bf --- /dev/null +++ b/nodedb/tests/sql_arithmetic_overflow.rs @@ -0,0 +1,198 @@ +//! Integration coverage for integer overflow handling in SQL expressions. +//! +//! The const-folder and procedural executor must detect integer overflow +//! and return an error rather than panicking (debug) or silently wrapping +//! (release). Float divide-by-zero must return an error, not ±Inf. + +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// Const-folder overflow (nodedb-sql planner) +// --------------------------------------------------------------------------- + +/// `i64::MAX + 1` in a constant expression must not panic (debug) or wrap +/// to a negative number (release). An error or null are both acceptable. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_addition_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT 9223372036854775807 + 1").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + assert!( + !val.contains("-9223372036854775808"), + "i64::MAX + 1 must not silently wrap to i64::MIN: got {val}" + ); + } + } + } +} + +/// `i64::MAX * 2` must not panic or wrap. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_multiplication_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT 9223372036854775807 * 2").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + // Wrapped value would be -2. + assert!( + !val.contains("\"-2\""), + "i64::MAX * 2 must not silently wrap: got {val}" + ); + } + } + } +} + +/// `i64::MIN - 1` must not panic or wrap. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_subtraction_overflow_does_not_wrap() { + let server = TestServer::start().await; + + let result = server.query_text("SELECT -9223372036854775808 - 1").await; + + match result { + Err(_) => { /* error is acceptable */ } + Ok(rows) => { + if let Some(val) = rows.first() { + // Wrapped value would be i64::MAX = 9223372036854775807. 
+ assert!( + !val.contains("9223372036854775807"), + "i64::MIN - 1 must not silently wrap to i64::MAX: got {val}" + ); + } + } + } +} + +// --------------------------------------------------------------------------- +// Const-folder in INSERT context +// --------------------------------------------------------------------------- + +/// Overflow in an INSERT VALUES expression should be caught before storage. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn const_fold_overflow_in_insert_values() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION overflow_tbl TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + v BIGINT)", + ) + .await + .unwrap(); + + let result = server + .exec("INSERT INTO overflow_tbl (id, v) VALUES ('k', 9223372036854775807 * 2)") + .await; + + assert!( + result.is_err(), + "INSERT with overflowing constant should fail, not store wrapped value" + ); +} + +// --------------------------------------------------------------------------- +// Procedural executor overflow (triggers / DO blocks) +// --------------------------------------------------------------------------- + +/// Integer overflow in a DO block's variable arithmetic should error, not +/// panic or silently wrap. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn procedural_eval_integer_overflow_returns_error() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION proc_ov TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, v BIGINT)") + .await + .unwrap(); + + // A DO block (or procedure) that overflows during evaluation. + let result = server + .exec( + "DO $$ \ + DECLARE x BIGINT := 9223372036854775807; \ + BEGIN \ + x := x + 1; \ + INSERT INTO proc_ov (id, v) VALUES ('k', x); \ + END $$", + ) + .await; + + assert!( + result.is_err(), + "integer overflow in procedural block should error, not wrap" + ); +} + +/// `i64::MIN / -1` overflows (the true quotient, `i64::MAX + 1`, is not +/// representable in two's complement) and must error. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn procedural_eval_min_div_neg1_returns_error() { + let server = TestServer::start().await; + + let result = server + .exec( + "DO $$ \ + DECLARE x BIGINT := -9223372036854775808; \ + DECLARE y BIGINT; \ + BEGIN \ + y := x / -1; \ + END $$", + ) + .await; + + assert!( + result.is_err(), + "i64::MIN / -1 in procedural block should error, not panic" + ); +} + +/// Float divide by negative zero should return an error or NULL, not ±Inf. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn procedural_eval_float_div_neg_zero() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION fdiv TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, v FLOAT)") + .await + .unwrap(); + + // -0.0 passes the `!= 0.0` guard in f64 comparison but produces -inf. + let result = server + .exec( + "DO $$ \ + DECLARE a FLOAT := 1.0; \ + DECLARE b FLOAT := -0.0; \ + DECLARE c FLOAT; \ + BEGIN \ + c := a / b; \ + INSERT INTO fdiv (id, v) VALUES ('k', c); \ + END $$", + ) + .await; + + // Either error, or if it succeeds, the stored value must not be infinity.
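+    // IEEE 754 note: -0.0 == 0.0 compares true, yet 1.0 / -0.0 is
+    // f64::NEG_INFINITY; that is exactly what a `b != 0.0` guard misses.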
+ if result.is_ok() { + let rows = server + .query_text("SELECT v FROM fdiv WHERE id = 'k'") + .await + .unwrap(); + if let Some(val) = rows.first() { + assert!( + !val.to_lowercase().contains("inf"), + "float / -0.0 should not produce infinity: got {val}" + ); + } + } +} diff --git a/nodedb/tests/sql_default_expressions.rs b/nodedb/tests/sql_default_expressions.rs new file mode 100644 index 00000000..4091f32f --- /dev/null +++ b/nodedb/tests/sql_default_expressions.rs @@ -0,0 +1,170 @@ +//! Integration coverage for DEFAULT expression evaluation in INSERT. +//! +//! The planner's `evaluate_default_expr` recognizes only a fixed keyword list +//! (UUID_V7, NOW(), NANOID, literals). Any other expression returns None, +//! causing the column to be silently omitted. These tests verify that +//! expression-based defaults are evaluated, not dropped. + +mod common; + +use common::pgwire_harness::TestServer; + +/// `DEFAULT upper('x')` — a scalar function call as a default value. +/// The planner should evaluate this rather than dropping the column. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_scalar_function_upper() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_fn TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + a TEXT DEFAULT upper('x'))", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_fn (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT a FROM def_fn WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1, "row should exist"); + // The default should produce 'X'. If the column was silently dropped, + // the value will be null/absent. + assert!( + rows[0].contains('X'), + "DEFAULT upper('x') should produce 'X', got {:?}", + rows[0] + ); +} + +/// `DEFAULT lower('HELLO')` — another scalar function. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_scalar_function_lower() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_lower TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + tag TEXT DEFAULT lower('HELLO'))", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_lower (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT tag FROM def_lower WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("hello"), + "DEFAULT lower('HELLO') should produce 'hello', got {:?}", + rows[0] + ); +} + +/// `DEFAULT 1 + 2` — a binary arithmetic expression as default. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_arithmetic_expression() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_arith TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + v INT DEFAULT 1 + 2)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_arith (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT v FROM def_arith WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains('3'), + "DEFAULT 1 + 2 should produce 3, got {:?}", + rows[0] + ); +} + +/// `DEFAULT concat('a', 'b')` — a multi-arg function. 
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_concat_function() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_concat TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + label TEXT DEFAULT concat('hello', '_', 'world'))", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_concat (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT label FROM def_concat WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("hello_world"), + "DEFAULT concat should produce 'hello_world', got {:?}", + rows[0] + ); +} + +/// Verify that recognized defaults (literal string, NOW(), UUID_V7) still work. +/// This is a baseline — not a new bug, just ensures we don't regress. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn default_recognized_expressions_still_work() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION def_known TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + status TEXT DEFAULT 'active', \ + uid TEXT DEFAULT UUID_V7)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO def_known (id) VALUES ('k1')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT status FROM def_known WHERE id = 'k1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("active"), + "DEFAULT 'active' should work: got {:?}", + rows[0] + ); +} diff --git a/nodedb/tests/sql_join_correctness.rs b/nodedb/tests/sql_join_correctness.rs new file mode 100644 index 00000000..7dd3a86a --- /dev/null +++ b/nodedb/tests/sql_join_correctness.rs @@ -0,0 +1,338 @@ +//! Integration coverage for SQL JOIN correctness. +//! +//! Covers: RIGHT JOIN plan conversion (inline swap), multi-predicate join +//! conditions (non-equi predicates), and NATURAL JOIN handling. 
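+//!
+//! Sketch of the rewrite under test (as described on the first RIGHT JOIN
+//! test below): `A RIGHT JOIN B ON p` is planned as `B LEFT JOIN A ON p`,
+//! so any side-specific state (inline_left / inline_right) must swap
+//! together with the operands, or nested-join inputs resolve against the
+//! wrong side.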
+ +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// Helper: set up three related tables for join tests +// --------------------------------------------------------------------------- + +async fn setup_join_tables(server: &TestServer) { + server + .exec( + "CREATE COLLECTION j_t1 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + name TEXT, \ + x INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION j_t2 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + t1_id TEXT, \ + y INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION j_t3 TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + t2_id TEXT, \ + z INT)", + ) + .await + .unwrap(); + + // t1 data + server + .exec("INSERT INTO j_t1 (id, name, x) VALUES ('a', 'Alice', 10)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t1 (id, name, x) VALUES ('b', 'Bob', 20)") + .await + .unwrap(); + + // t2 data — references t1 + server + .exec("INSERT INTO j_t2 (id, t1_id, y) VALUES ('p', 'a', 100)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t2 (id, t1_id, y) VALUES ('q', 'a', 200)") + .await + .unwrap(); + + // t3 data — references t2, plus one unmatched + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('x', 'p', 1)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('y', 'q', 2)") + .await + .unwrap(); + server + .exec("INSERT INTO j_t3 (id, t2_id, z) VALUES ('z', 'NONE', 3)") + .await + .unwrap(); +} + +// --------------------------------------------------------------------------- +// RIGHT JOIN with nested join on the left +// --------------------------------------------------------------------------- + +/// A RIGHT JOIN where the left side is itself a join. The planner rewrites +/// RIGHT → LEFT by swapping sides, but must also swap inline_left/inline_right. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn right_join_with_nested_left_join_returns_correct_rows() { + let server = TestServer::start().await; + setup_join_tables(&server).await; + + // t3 has 3 rows. RIGHT JOIN means all t3 rows appear, with NULLs for + // unmatched left-side (t1 JOIN t2) rows. + let rows = server + .query_text( + "SELECT j_t3.id FROM j_t1 \ + INNER JOIN j_t2 ON j_t1.id = j_t2.t1_id \ + RIGHT JOIN j_t3 ON j_t2.id = j_t3.t2_id", + ) + .await + .unwrap(); + + // All 3 t3 rows must appear (x, y, z). z has no match in the left join + // but RIGHT JOIN preserves it. + assert_eq!( + rows.len(), + 3, + "RIGHT JOIN should preserve all right-side rows: got {rows:?}" + ); +} + +/// Simple RIGHT JOIN (no nested join) as a baseline. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn simple_right_join_preserves_right_rows() { + let server = TestServer::start().await; + setup_join_tables(&server).await; + + let rows = server + .query_text( + "SELECT j_t3.id FROM j_t2 \ + RIGHT JOIN j_t3 ON j_t2.id = j_t3.t2_id", + ) + .await + .unwrap(); + + assert_eq!( + rows.len(), + 3, + "simple RIGHT JOIN should return all 3 t3 rows: got {rows:?}" + ); +} + +// --------------------------------------------------------------------------- +// Multiple non-equi predicates in JOIN ON +// --------------------------------------------------------------------------- + +/// All non-equi predicates in a JOIN ON clause must be preserved. +/// Only keeping the first one silently drops filter conditions. 
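+///
+/// e.g. in `ON a.x = b.x AND a.y < b.y AND a.y > 0`, the equi key is
+/// `a.x = b.x` and BOTH residual range predicates must survive planning.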
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn join_preserves_all_non_equi_predicates() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION jn_left TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x INT, \ + y INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION jn_right TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x INT, \ + y INT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO jn_left (id, x, y) VALUES ('L1', 10, 5)") + .await + .unwrap(); + server + .exec("INSERT INTO jn_left (id, x, y) VALUES ('L2', 10, 50)") + .await + .unwrap(); + + server + .exec("INSERT INTO jn_right (id, x, y) VALUES ('R1', 10, 20)") + .await + .unwrap(); + + // Both non-equi conditions must apply. Data: L1 (x=10, y=5) and + // L2 (x=10, y=50) on the left; R1 (x=10, y=20) on the right. + // Condition: jn_left.x = jn_right.x AND jn_left.y < jn_right.y AND jn_left.y > 0 + // L1: x = x, 5 < 20, 5 > 0 → match + // L2: x = x, but 50 < 20 is FALSE → no match + + let rows = server + .query_text( + "SELECT jn_left.id FROM jn_left \ + JOIN jn_right ON jn_left.x = jn_right.x \ + AND jn_left.y < jn_right.y \ + AND jn_left.y > 0", + ) + .await + .unwrap(); + + assert_eq!( + rows.len(), + 1, + "only L1 should match all three join conditions: got {rows:?}" + ); + assert!( + rows[0].contains("L1"), + "matched row should be L1: got {:?}", + rows[0] + ); +} + +/// If the second non-equi predicate is dropped, L2 would incorrectly appear. +/// This is a regression guard: assert L2 is NOT in the result. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn join_does_not_drop_second_non_equi_predicate() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION jd_a TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val INT, \ + lo INT, \ + hi INT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION jd_b TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val INT, \ + bound INT)", + ) + .await + .unwrap(); + + // A: val=1, lo=0, hi=100. B: val=1, bound=50. + // Condition: a.val = b.val AND a.lo < b.bound AND a.hi > b.bound + // → 0 < 50 AND 100 > 50 → match. + server + .exec("INSERT INTO jd_a (id, val, lo, hi) VALUES ('a1', 1, 0, 100)") + .await + .unwrap(); + // A2: val=1, lo=0, hi=10. + // → 0 < 50 AND 10 > 50 → no match (second non-equi fails). + server + .exec("INSERT INTO jd_a (id, val, lo, hi) VALUES ('a2', 1, 0, 10)") + .await + .unwrap(); + + server + .exec("INSERT INTO jd_b (id, val, bound) VALUES ('b1', 1, 50)") + .await + .unwrap(); + + let rows = server + .query_text( + "SELECT jd_a.id FROM jd_a \ + JOIN jd_b ON jd_a.val = jd_b.val \ + AND jd_a.lo < jd_b.bound \ + AND jd_a.hi > jd_b.bound", + ) + .await + .unwrap(); + + assert_eq!(rows.len(), 1, "only a1 should match: got {rows:?}"); + assert!( + rows[0].contains("a1"), + "matched row should be a1: got {:?}", + rows[0] + ); + // Regression guard: if the second non-equi (`hi > bound`) was dropped, + // a2 would also appear.
+ assert!( + !rows.iter().any(|r| r.contains("a2")), + "a2 should NOT match — hi=10 is not > bound=50: got {rows:?}" + ); +} + +// --------------------------------------------------------------------------- +// NATURAL JOIN +// --------------------------------------------------------------------------- + +/// NATURAL JOIN should either compute the shared column set and produce a +/// correct equi-join, or return an explicit unsupported error. It must NOT +/// silently produce a cartesian product. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn natural_join_is_not_cartesian() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION nat_a TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + shared_col TEXT)", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION nat_b TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + shared_col TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO nat_a (id, shared_col) VALUES ('a1', 'x')") + .await + .unwrap(); + server + .exec("INSERT INTO nat_a (id, shared_col) VALUES ('a2', 'y')") + .await + .unwrap(); + + server + .exec("INSERT INTO nat_b (id, shared_col) VALUES ('b1', 'x')") + .await + .unwrap(); + + let result = server + .query_text("SELECT nat_a.id FROM nat_a NATURAL JOIN nat_b") + .await; + + match result { + Ok(rows) => { + // If NATURAL JOIN is supported, shared_col='x' should produce + // a match on a1↔b1 only. Cartesian would give 2 rows. + assert!( + rows.len() <= 1, + "NATURAL JOIN should match on shared_col, not produce cartesian ({} rows): {rows:?}", + rows.len() + ); + } + Err(msg) => { + // An explicit "unsupported" error is acceptable. + assert!( + msg.to_lowercase().contains("natural") + || msg.to_lowercase().contains("unsupported") + || msg.to_lowercase().contains("not supported"), + "error should mention NATURAL JOIN: {msg}" + ); + } + } +} diff --git a/nodedb/tests/sql_parser_string_handling.rs b/nodedb/tests/sql_parser_string_handling.rs new file mode 100644 index 00000000..17728ee6 --- /dev/null +++ b/nodedb/tests/sql_parser_string_handling.rs @@ -0,0 +1,279 @@ +//! Integration coverage for SQL parsers handling string literals correctly. +//! +//! Hand-rolled parsers must skip over single-quoted string literals when +//! scanning for structural delimiters (`BEGIN`, `)`, `,`, `WITH`, `{}`). +//! These tests verify that embedded keywords and special characters inside +//! string values do not corrupt the parse. + +mod common; + +use common::pgwire_harness::TestServer; + +// --------------------------------------------------------------------------- +// 1. CREATE TRIGGER: `BEGIN` inside a string literal in WHEN clause +// --------------------------------------------------------------------------- + +/// A WHEN clause containing the word 'BEGIN' inside a string literal must not +/// confuse the header/body split. The trigger should parse correctly and the +/// body should execute as written. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn trigger_begin_inside_string_literal_does_not_corrupt_parse() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION items TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, label TEXT)") + .await + .unwrap(); + + server.exec("CREATE COLLECTION audit_log").await.unwrap(); + + // The WHEN clause references the string 'BEGIN' — the parser must not + // treat this as the body start. 
+ let result = server + .exec( + "CREATE TRIGGER tr_begin BEFORE INSERT ON items FOR EACH ROW \ + WHEN (NEW.label = 'BEGIN' OR NEW.label = 'open') \ + BEGIN INSERT INTO audit_log { id: NEW.id, note: 'triggered' }; END", + ) + .await; + assert!( + result.is_ok(), + "CREATE TRIGGER with 'BEGIN' in WHEN string should succeed: {:?}", + result + ); + + // The trigger should not interfere with a normal insert. + server + .exec("INSERT INTO items (id, label) VALUES ('i1', 'hello')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT id FROM items WHERE id = 'i1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1, "insert should succeed through the trigger"); +} + +// --------------------------------------------------------------------------- +// 2. INSERT: `)` inside a string literal in VALUES +// --------------------------------------------------------------------------- + +/// A string value containing `)` must not break the VALUES pre-parse that +/// scans for the closing paren. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_paren_inside_string_value_parses_correctly() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION msgs TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + body TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO msgs (id, body) VALUES ('m1', 'hello)world')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT body FROM msgs WHERE id = 'm1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1, "insert with ) in string should succeed"); + assert!( + rows[0].contains("hello)world"), + "value should preserve the literal paren: got {:?}", + rows[0] + ); +} + +/// Multiple special characters inside string values in a multi-column INSERT. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_multiple_parens_in_strings() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION notes TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + a TEXT, \ + b TEXT)", + ) + .await + .unwrap(); + + server + .exec("INSERT INTO notes (id, a, b) VALUES ('n1', 'a)b', 'c(d')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT a FROM notes WHERE id = 'n1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!(rows[0].contains("a)b"), "got {:?}", rows[0]); +} + +// --------------------------------------------------------------------------- +// 3. split_values: quote tracking inside brackets/arrays +// --------------------------------------------------------------------------- + +/// An array literal containing strings with `)` must not confuse the value +/// splitter's bracket-depth tracking. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn insert_array_with_paren_in_string_element() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION arr_test TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + tags TEXT, \ + count INT)", + ) + .await + .unwrap(); + + // ARRAY['x)y', 'z'] — the `)` inside the first element must not + // decrement the bracket counter. 
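+    // If the splitter ignored quote state, the ')' inside 'x)y' would close
+    // the VALUES list early and the trailing 42 would fall outside the parse.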
+ server + .exec("INSERT INTO arr_test (id, tags, count) VALUES ('a1', ARRAY['x)y', 'z'], 42)") + .await + .unwrap(); + + let rows = server + .query_text("SELECT count FROM arr_test WHERE id = 'a1'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "insert with array containing ')' should succeed" + ); + assert!( + rows[0].contains("42"), + "second column value should be 42, got {:?}", + rows[0] + ); +} + +// --------------------------------------------------------------------------- +// 5. Object-literal `{ }` rewriter: `''`-escaped quotes +// --------------------------------------------------------------------------- + +/// The `{ }` preprocessor must handle SQL-escaped single quotes (`''`). +/// `'it''s'` is a valid SQL string containing an apostrophe. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn object_literal_with_escaped_quote_parses_correctly() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION esc_test").await.unwrap(); + + // The `{ note: 'it''s' }` syntax goes through `find_matching_brace` + // which must correctly handle the `''` escape. + server + .exec("INSERT INTO esc_test { id: 'e1', note: 'it''s fine' }") + .await + .unwrap(); + + let rows = server + .query_text("SELECT * FROM esc_test WHERE id = 'e1'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "insert with escaped quote in object literal should succeed" + ); + assert!( + rows[0].contains("it's fine") || rows[0].contains("it''s fine"), + "value should contain the apostrophe: got {:?}", + rows[0] + ); +} + +/// Escaped quotes adjacent to the closing brace. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn object_literal_escaped_quote_near_brace() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION esc_test2").await.unwrap(); + + server + .exec("INSERT INTO esc_test2 { id: 'e2', val: 'end''s}' }") + .await + .unwrap(); + + let rows = server + .query_text("SELECT * FROM esc_test2 WHERE id = 'e2'") + .await + .unwrap(); + assert_eq!( + rows.len(), + 1, + "escaped quote near brace should parse correctly" + ); +} + +// --------------------------------------------------------------------------- +// 7. CREATE MATERIALIZED VIEW: `WITH` inside SELECT body +// --------------------------------------------------------------------------- + +/// A materialized view whose SELECT contains a CTE (`WITH cte AS (...)`) +/// must not have the CTE keyword mistaken for the options clause. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn materialized_view_with_cte_in_select() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION mv_src TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, val INT)") + .await + .unwrap(); + server + .exec("INSERT INTO mv_src (id, val) VALUES ('r1', 10)") + .await + .unwrap(); + + // The SELECT uses a CTE — `WITH s AS (...)`. The parser must not + // truncate the query at the CTE's `WITH`. + let result = server + .exec( + "CREATE MATERIALIZED VIEW mv_cte ON mv_src AS \ + WITH s AS (SELECT id, val FROM mv_src) SELECT id, val FROM s", + ) + .await; + assert!( + result.is_ok(), + "CREATE MATERIALIZED VIEW with CTE should succeed: {:?}", + result + ); +} + +/// A materialized view whose SELECT contains a column aliased with `with_`. 
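+///
+/// (Per the backward scan in `find_trailing_with_options`: a real options
+/// clause is `WITH` at a word boundary followed by `(`; `with_tax` has
+/// neither, so it must be skipped.)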
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn materialized_view_with_keyword_in_column_alias() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION mv_src2 TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, val INT)") + .await + .unwrap(); + + // `with_tax` as a column alias should not be mistaken for `WITH` options. + let result = server + .exec( + "CREATE MATERIALIZED VIEW mv_alias ON mv_src2 AS \ + SELECT id, val AS with_tax FROM mv_src2", + ) + .await; + assert!( + result.is_ok(), + "column alias starting with 'with_' should not break parse: {:?}", + result + ); +} diff --git a/nodedb/tests/sql_procedure_cache_safety.rs b/nodedb/tests/sql_procedure_cache_safety.rs new file mode 100644 index 00000000..6461e174 --- /dev/null +++ b/nodedb/tests/sql_procedure_cache_safety.rs @@ -0,0 +1,106 @@ +//! Integration coverage for procedure/trigger body cache correctness. +//! +//! The ProcedureBlockCache keys on a 64-bit hash of the body SQL without +//! storing the source for equality verification. A hash collision returns +//! the wrong compiled block. This test verifies that two different trigger +//! bodies always execute their own logic, not each other's. + +mod common; + +use std::time::Duration; + +use common::pgwire_harness::TestServer; + +/// Two triggers with different bodies on different collections must each +/// execute their own body, never the other's. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn different_trigger_bodies_execute_independently() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION cache_a TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val TEXT DEFAULT 'none')", + ) + .await + .unwrap(); + server + .exec( + "CREATE COLLECTION cache_b TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + val TEXT DEFAULT 'none')", + ) + .await + .unwrap(); + // Audit log to capture trigger side effects. + server.exec("CREATE COLLECTION trigger_log").await.unwrap(); + + // Trigger A: logs 'trigger_a_fired'. + server + .exec( + "CREATE TRIGGER trg_a AFTER INSERT ON cache_a FOR EACH ROW \ + BEGIN \ + INSERT INTO trigger_log { id: 'log_a', source: 'trigger_a_fired' }; \ + END", + ) + .await + .unwrap(); + + // Trigger B: logs 'trigger_b_fired' — different body. + server + .exec( + "CREATE TRIGGER trg_b AFTER INSERT ON cache_b FOR EACH ROW \ + BEGIN \ + INSERT INTO trigger_log { id: 'log_b', source: 'trigger_b_fired' }; \ + END", + ) + .await + .unwrap(); + + // Fire trigger A. + server + .exec("INSERT INTO cache_a (id) VALUES ('a1')") + .await + .unwrap(); + // Wait for async AFTER trigger dispatch. + tokio::time::sleep(Duration::from_millis(500)).await; + + // Fire trigger B. + server + .exec("INSERT INTO cache_b (id) VALUES ('b1')") + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(500)).await; + + // Verify trigger A logged 'trigger_a_fired', not 'trigger_b_fired'. + let log_a = server + .query_text("SELECT * FROM trigger_log WHERE id = 'log_a'") + .await + .unwrap(); + assert_eq!(log_a.len(), 1, "trigger A should have logged: {log_a:?}"); + assert!( + log_a[0].contains("trigger_a_fired"), + "trigger A should execute its own body, not B's: {:?}", + log_a[0] + ); + // Regression guard: if cache collision occurred, A's log entry would + // contain B's message. + assert!( + !log_a[0].contains("trigger_b_fired"), + "trigger A must not execute trigger B's body (cache collision): {:?}", + log_a[0] + ); + + // Verify trigger B logged its own message. 
+ let log_b = server + .query_text("SELECT * FROM trigger_log WHERE id = 'log_b'") + .await + .unwrap(); + assert_eq!(log_b.len(), 1, "trigger B should have logged: {log_b:?}"); + assert!( + log_b[0].contains("trigger_b_fired"), + "trigger B should execute its own body: {:?}", + log_b[0] + ); +} diff --git a/nodedb/tests/sql_recursive_cte.rs b/nodedb/tests/sql_recursive_cte.rs new file mode 100644 index 00000000..b9ee387b --- /dev/null +++ b/nodedb/tests/sql_recursive_cte.rs @@ -0,0 +1,171 @@ +//! Integration coverage for WITH RECURSIVE (recursive CTEs). +//! +//! The planner must preserve the recursive structure (base + recursive branch +//! with self-reference), respect UNION vs UNION ALL, and allow configurable +//! max iteration depth. + +mod common; + +use common::pgwire_harness::TestServer; + +/// Basic recursive CTE: generate numbers 1..5 using UNION ALL. +/// The recursive branch references the CTE and increments. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn recursive_cte_generates_sequence() { + let server = TestServer::start().await; + + // We need a base collection for the query to target. + server + .exec("CREATE COLLECTION cte_dummy TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY)") + .await + .unwrap(); + + let rows = server + .query_text( + "WITH RECURSIVE c(n) AS (\ + SELECT 1 \ + UNION ALL \ + SELECT n + 1 FROM c WHERE n < 5\ + ) \ + SELECT n FROM c", + ) + .await; + + match rows { + Ok(values) => { + assert_eq!( + values.len(), + 5, + "recursive CTE should produce 5 rows (1..5): got {values:?}" + ); + // Values should be 1, 2, 3, 4, 5 in some order. + let nums: Vec<i64> = values + .iter() + .filter_map(|v| v.trim().parse().ok()) + .collect(); + let mut sorted = nums.clone(); + sorted.sort(); + assert_eq!( + sorted, + vec![1, 2, 3, 4, 5], + "recursive CTE should produce 1..5: got {nums:?}" + ); + } + Err(msg) => { + // If recursive CTEs are not yet supported, an explicit error is acceptable + // but a silent empty result or wrong collection is not. + assert!( + msg.to_lowercase().contains("recursive") + || msg.to_lowercase().contains("not supported") + || msg.to_lowercase().contains("unsupported"), + "error should mention recursive CTE: {msg}" + ); + } + } +} + +/// UNION ALL in recursive CTE must preserve duplicates. +/// UNION (without ALL) should deduplicate. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn recursive_cte_union_all_preserves_duplicates() { + let server = TestServer::start().await; + + server + .exec("CREATE COLLECTION cte_dup TYPE DOCUMENT STRICT (id TEXT PRIMARY KEY, grp TEXT)") + .await + .unwrap(); + + // Base branch: two identical rows (SELECT 1 UNION ALL SELECT 1); there is + // no recursive branch. UNION ALL must keep both rows; UNION would deduplicate. + let result_all = server + .query_text( + "WITH RECURSIVE c(val) AS (\ + SELECT 1 UNION ALL SELECT 1\ + ) \ + SELECT val FROM c", + ) + .await; + + match result_all { + Ok(rows) => { + assert_eq!( + rows.len(), + 2, + "UNION ALL should preserve duplicate base rows: got {rows:?}" + ); + } + Err(_) => { + // Acceptable if recursive CTEs aren't supported yet — but the + // test failing here captures the coverage gap. + } + } +} + +/// Recursive CTE over actual table data: tree traversal.
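+///
+/// Expected frontier per iteration under the working-table algorithm:
+/// {root} → {child1} → {grandchild} → {} (fixed point). 'orphan' names a
+/// parent_id that never enters any frontier, so it is never joined in.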
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn recursive_cte_tree_traversal() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION tree TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + parent_id TEXT)", + ) + .await + .unwrap(); + + // Build a small tree: root → child1 → grandchild + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('root', NULL)") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('child1', 'root')") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('grandchild', 'child1')") + .await + .unwrap(); + server + .exec("INSERT INTO tree (id, parent_id) VALUES ('orphan', 'missing')") + .await + .unwrap(); + + // Traverse descendants of 'root'. + let rows = server + .query_text( + "WITH RECURSIVE descendants(id) AS (\ + SELECT id FROM tree WHERE id = 'root' \ + UNION ALL \ + SELECT t.id FROM tree t \ + INNER JOIN descendants d ON t.parent_id = d.id\ + ) \ + SELECT id FROM descendants", + ) + .await; + + match rows { + Ok(values) => { + // Should find: root, child1, grandchild (3 rows). NOT orphan. + assert_eq!( + values.len(), + 3, + "tree traversal should find root + child1 + grandchild: got {values:?}" + ); + assert!( + !values.iter().any(|v| v.contains("orphan")), + "orphan should not appear in descendants of root: {values:?}" + ); + } + Err(msg) => { + // An explicit unsupported error is acceptable — the test captures the gap. + assert!( + msg.to_lowercase().contains("recursive") + || msg.to_lowercase().contains("not supported"), + "unexpected error: {msg}" + ); + } + } +} diff --git a/nodedb/tests/sql_rls_predicate_parse.rs b/nodedb/tests/sql_rls_predicate_parse.rs new file mode 100644 index 00000000..0361073d --- /dev/null +++ b/nodedb/tests/sql_rls_predicate_parse.rs @@ -0,0 +1,120 @@ +//! Integration coverage for RLS policy predicate parsing. +//! +//! The USING clause predicate must preserve balanced parentheses faithfully. +//! `trim_matches('(' | ')')` strips all leading/trailing parens — not just +//! one matched pair — corrupting nested boolean expressions. + +mod common; + +use common::pgwire_harness::TestServer; + +/// Double-wrapped balanced parens `((x > 0) AND (y = 1))` must be stored +/// correctly and enforced at query time. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rls_policy_nested_parens_preserved() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION rls_items TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x INT, \ + y INT)", + ) + .await + .unwrap(); + + // Create a policy with nested parens in USING predicate. + let result = server + .exec( + "CREATE RLS POLICY pos_check ON rls_items FOR READ \ + USING ((x > 0) AND (y = 1))", + ) + .await; + assert!( + result.is_ok(), + "CREATE RLS POLICY with nested parens should succeed: {:?}", + result + ); + + // Insert rows: one matching, one not. + server + .exec("INSERT INTO rls_items (id, x, y) VALUES ('ok', 5, 1)") + .await + .unwrap(); + server + .exec("INSERT INTO rls_items (id, x, y) VALUES ('bad', -1, 1)") + .await + .unwrap(); + + // Create a non-superuser to test RLS enforcement. 
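+    // (Assumed semantics: superusers bypass RLS, so enforcement is only
+    // observable from a restricted role.)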
server .exec("CREATE USER rls_reader WITH PASSWORD 'pass' ROLE readonly") .await .unwrap(); + + let (reader, _handle) = server.connect_as("rls_reader", "pass").await.unwrap(); + let result = reader.simple_query("SELECT id FROM rls_items").await; + match result { + Ok(msgs) => { + let mut ids: Vec<String> = Vec::new(); + for msg in &msgs { + if let tokio_postgres::SimpleQueryMessage::Row(row) = msg { + ids.push(row.get(0).unwrap_or("").to_string()); + } + } + // The policy should allow `ok` (x=5 > 0, y=1) and reject `bad` (x=-1). + assert!( + ids.iter().any(|id| id.contains("ok")), + "row with x=5,y=1 should be visible under RLS: {ids:?}" + ); + assert!( + !ids.iter().any(|id| id.contains("bad")), + "row with x=-1 should be filtered by RLS: {ids:?}" + ); + } + Err(e) => { + // Extract detailed error if available. + let detail = if let Some(db_err) = e.as_db_error() { + format!( + "{}: {} (SQLSTATE {})", + db_err.severity(), + db_err.message(), + db_err.code().code() + ) + } else { + format!("{e:?}") + }; + panic!("RLS query should not error — predicate may be corrupted: {detail}"); + } + } +} + +/// Triple-nested parens should not be over-stripped. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rls_policy_triple_nested_parens() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION rls_deep TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + a INT, \ + b INT, \ + c INT)", + ) + .await + .unwrap(); + + let result = server + .exec( + "CREATE RLS POLICY deep_check ON rls_deep FOR READ \ + USING (((a > 0) AND (b > 0)) OR (c = 99))", + ) + .await; + assert!( + result.is_ok(), + "deeply nested RLS predicate should parse without corruption: {:?}", + result + ); +} diff --git a/nodedb/tests/sql_trigger_fuel.rs b/nodedb/tests/sql_trigger_fuel.rs new file mode 100644 index 00000000..0278a38a --- /dev/null +++ b/nodedb/tests/sql_trigger_fuel.rs @@ -0,0 +1,98 @@ +//! Integration coverage for trigger execution fuel / budget limits. +//! +//! Trigger bodies run user-supplied code. An infinite loop must not pin a +//! Control Plane worker for the full 3600-second deadline. The execution +//! budget must cap iterations and wall-clock time to a reasonable bound. + +mod common; + +use std::time::{Duration, Instant}; + +use common::pgwire_harness::TestServer; + +/// An infinite-loop trigger body must be terminated by the execution budget +/// within a reasonable time (< 30 seconds), not the 1-hour wall clock. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn infinite_loop_trigger_terminates_within_budget() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION fuel_test TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + v INT)", + ) + .await + .unwrap(); + + // Create a trigger with an infinite loop body. + server + .exec( + "CREATE TRIGGER fuel_loop AFTER INSERT ON fuel_test FOR EACH ROW \ + BEGIN LOOP END LOOP; END", + ) + .await + .unwrap(); + + let start = Instant::now(); + let result = server + .exec("INSERT INTO fuel_test (id, v) VALUES ('a', 1)") + .await; + let elapsed = start.elapsed(); + + // The trigger should either: + // (a) error because the budget was exhausted, or + // (b) succeed but complete within a sane time bound. + // It must NOT hang for 3600 seconds.
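+    // (Assumption: the budget in play is ExecutionBudget::trigger_default(),
+    // the bounded replacement for the old unlimited(); see the renamed unit
+    // test in procedure_execution.rs.)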
+ assert!( + elapsed < Duration::from_secs(30), + "infinite-loop trigger took {elapsed:?} — budget should cap execution well under 1 hour" + ); + + if let Err(msg) = result { + assert!( + msg.to_lowercase().contains("fuel") + || msg.to_lowercase().contains("budget") + || msg.to_lowercase().contains("timeout") + || msg.to_lowercase().contains("iteration") + || msg.to_lowercase().contains("limit"), + "error should mention budget/fuel/timeout: {msg}" + ); + } +} + +/// A trigger with a bounded but expensive loop should succeed if within budget. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn bounded_loop_trigger_completes_normally() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION fuel_ok TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + counter INT DEFAULT 0)", + ) + .await + .unwrap(); + + // A trigger that loops 10 times — well within any reasonable budget. + server + .exec( + "CREATE TRIGGER fuel_bounded AFTER INSERT ON fuel_ok FOR EACH ROW \ + BEGIN \ + DECLARE i INT := 0; \ + WHILE i < 10 LOOP \ + i := i + 1; \ + END LOOP; \ + END", + ) + .await + .unwrap(); + + let result = server.exec("INSERT INTO fuel_ok (id) VALUES ('b1')").await; + assert!( + result.is_ok(), + "bounded loop trigger should complete normally: {:?}", + result + ); +} diff --git a/nodedb/tests/sql_utf8_expressions.rs b/nodedb/tests/sql_utf8_expressions.rs new file mode 100644 index 00000000..9c7bd08f --- /dev/null +++ b/nodedb/tests/sql_utf8_expressions.rs @@ -0,0 +1,145 @@ +//! Integration coverage for UTF-8 correctness in expression parsing. +//! +//! The expression tokenizer in `nodedb-query` iterates byte-by-byte but +//! slices `&str` — panicking on multi-byte UTF-8 codepoints. These tests +//! verify that non-ASCII characters in generated columns, check constraints, +//! and typeguard expressions do not cause panics or data corruption. + +mod common; + +use common::pgwire_harness::TestServer; + +/// A GENERATED ALWAYS AS expression containing a CJK literal must not panic. +/// The tokenizer slices `&input[i..i+2]` which crosses a char boundary on +/// 3-byte UTF-8 codepoints. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn generated_column_with_cjk_literal_does_not_panic() { + let server = TestServer::start().await; + + let result = server + .exec( + "CREATE COLLECTION utf_gen TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + x TEXT, \ + y TEXT GENERATED ALWAYS AS ('你' || x))", + ) + .await; + assert!( + result.is_ok(), + "CJK literal in GENERATED expression should not panic: {:?}", + result + ); + + server + .exec("INSERT INTO utf_gen (id, x) VALUES ('u1', 'hello')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT y FROM utf_gen WHERE id = 'u1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("你hello"), + "generated column should concatenate CJK prefix: got {:?}", + rows[0] + ); +} + +/// Emoji (4-byte UTF-8) in an expression must not panic. 
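+/// ('🎉' is U+1F389, four bytes in UTF-8, so a byte-indexed two-char operator
+/// probe such as `&input[i..i + 2]` starting at it lands mid-codepoint.)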
+#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn generated_column_with_emoji_does_not_panic() { + let server = TestServer::start().await; + + let result = server + .exec( + "CREATE COLLECTION utf_emoji TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + tag TEXT, \ + display TEXT GENERATED ALWAYS AS ('🎉' || tag))", + ) + .await; + assert!( + result.is_ok(), + "emoji in GENERATED expression should not panic: {:?}", + result + ); +} + +/// Multi-byte characters followed by operators (`<=`, `>=`, `!=`, `||`) +/// are the specific trigger for the `&input[i..i+2]` boundary panic. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn check_constraint_with_utf8_near_operator() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION utf_ck TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + name TEXT)", + ) + .await + .unwrap(); + + // Check constraint with a multi-byte literal adjacent to `!=` + let result = server + .exec( + "ALTER COLLECTION utf_ck ADD CONSTRAINT no_forbidden \ + CHECK (NEW.name != '禁止')", + ) + .await; + assert!( + result.is_ok(), + "CHECK with UTF-8 literal near != operator should not panic: {:?}", + result + ); + + // Valid insert should pass the constraint. + server + .exec("INSERT INTO utf_ck (id, name) VALUES ('c1', 'allowed')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT id FROM utf_ck WHERE id = 'c1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); +} + +/// Latin diacritics (2-byte UTF-8) in expression strings. +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn generated_column_with_latin_diacritics() { + let server = TestServer::start().await; + + let result = server + .exec( + "CREATE COLLECTION utf_lat TYPE DOCUMENT STRICT (\ + id TEXT PRIMARY KEY, \ + city TEXT, \ + greeting TEXT GENERATED ALWAYS AS ('café in ' || city))", + ) + .await; + assert!( + result.is_ok(), + "Latin diacritics in GENERATED expression should not panic: {:?}", + result + ); + + server + .exec("INSERT INTO utf_lat (id, city) VALUES ('l1', 'Paris')") + .await + .unwrap(); + + let rows = server + .query_text("SELECT greeting FROM utf_lat WHERE id = 'l1'") + .await + .unwrap(); + assert_eq!(rows.len(), 1); + assert!( + rows[0].contains("café in Paris"), + "generated column should preserve diacritics: got {:?}", + rows[0] + ); +}
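+
+/// Not a server test: a minimal, self-contained sketch of the slicing hazard
+/// the tokenizer rewrite avoids. Indexing a &str at a non-char boundary
+/// panics; `str::get` and `char_indices` never do.
+#[test]
+fn byte_slice_across_char_boundary_illustration() {
+    let s = "你<=x"; // '你' is 3 bytes in UTF-8 (E4 BD A0).
+    // A fixed two-byte probe starting at the CJK char would split it:
+    // `&s[0..2]` panics, while `get` reports the bad boundary as None.
+    assert!(s.get(0..2).is_none());
+    assert_eq!(s.get(0..3), Some("你"));
+    // char_indices yields only valid boundaries — the safe way to scan.
+    let starts: Vec<usize> = s.char_indices().map(|(i, _)| i).collect();
+    assert_eq!(starts, vec![0, 3, 4, 5]);
+}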