From d3620e2852c40fe9c894eb22ab29a1a613449afd Mon Sep 17 00:00:00 2001 From: harehare Date: Thu, 25 Jun 2026 22:08:34 +0900 Subject: [PATCH] feat(sql): expand function library and add zone-map document pruning Add a comprehensive set of SQL scalar/aggregate functions (string, numeric, null-handling, CASE, group_concat/string_agg, COUNT DISTINCT) comparable to a general-purpose RDBMS. Also wire ZoneMaps into SqlEngine so a single, non-joined `SELECT ... FROM blocks` can skip whole documents based on `lang`, `depth`, or heading `content` predicates without scanning their blocks. Joins always disable this skip to avoid alias ambiguity. --- README.md | 82 ++++- src/indexes.rs | 4 + src/sql.rs | 946 +++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 984 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index ed8b4b7..2f6e563 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ flowchart TD - **Zone Maps** — per-document statistics skip irrelevant files before scanning any blocks - **Dual query engines** — SQL via a custom `sqlparser`-based evaluator, and `mq` via `mq-lang` - **DDL support** — `CREATE TABLE`, `INSERT INTO`, `DROP TABLE` for in-memory custom tables +- **Comprehensive SQL function library** — string, numeric, null-handling, `CASE`, and aggregate functions comparable to a general-purpose RDBMS - **`mq()` scalar function** — run an mq program against Markdown content inline in SQL - **Custom page-file persistence** — 8 KB fixed pages, checksums, atomic writes - **CLI + interactive REPL + TUI** — full terminal experience @@ -384,13 +385,64 @@ SELECT id, document_id, block_type, content, pre, post, ### Built-in functions +mq-db-specific: + | Function | Description | | ------------------------------------- | ------------------------------------------ | | `under(pre, post, anc_pre, anc_post)` | O(1) interval ancestor check | | `mq(program, content)` | Run an mq program against Markdown content | | `json_extract(json, path)` | Extract a value from a JSON string | -| `count(*) / min / max / sum / avg` | Aggregate functions | -| `lower / upper / length / coalesce` | Scalar utilities | + +String: + +| Function | Description | +| --- | --- | +| `lower` / `upper` | Case conversion | +| `length` / `len` | Character count | +| `trim` / `ltrim` / `rtrim` | Strip whitespace, or the given characters | +| `concat` / `concat_ws` | Join strings (with optional separator) | +| `replace` | Replace all occurrences of a substring | +| `substring` / `substr` | Extract a substring (1-based, `FROM`/`FOR` or comma form) | +| `position` / `instr` | Find the 1-based index of a substring (0 if absent) | +| `left` / `right` | First/last `n` characters | +| `lpad` / `rpad` | Pad to a fixed length | +| `reverse` | Reverse a string | +| `repeat` | Repeat a string `n` times | +| `initcap` | Capitalize each word | +| `ascii` / `chr` | Char ↔ code point | +| `split_part` | Extract the nth delimiter-separated field | + +Numeric: + +| Function | Description | +| --- | --- | +| `abs` | Absolute value | +| `round` / `trunc` | Round / truncate, with optional decimal scale | +| `ceil` / `floor` | Round up / down | +| `mod` | Remainder | +| `power` / `sqrt` | Exponentiation / square root | +| `exp` / `ln` | `e^x` / natural log | +| `log` / `log10` / `log2` | Logarithm (1-arg = base 10, 2-arg = custom base) | +| `sign` | `-1` / `0` / `1` | +| `pi` | π | +| `greatest` / `least` | Max / min across arguments (ignoring NULL) | + +Null handling & control flow: + +| Function | Description | +| --- | --- | +| `coalesce` / `ifnull` | First non-NULL argument | +| `nullif` | NULL if the two arguments are equal | +| `CASE WHEN … THEN … ELSE … END` | Conditional expressions | +| `typeof` | Runtime type of a value | + +Aggregates (usable with `GROUP BY`): + +| Function | Description | +| --- | --- | +| `count(*)` / `count(DISTINCT col)` | Row / distinct-value count | +| `min` / `max` / `sum` / `avg` | Standard aggregates | +| `group_concat` / `string_agg(expr[, sep])` | Concatenate group values (default separator `,`) | ### DDL statements @@ -431,6 +483,15 @@ WHERE h.block_type = 'heading' AND depth = 2 AND nxt.block_type = 'list'; SELECT DISTINCT d.path FROM documents d JOIN blocks b ON b.document_id = d.id WHERE b.block_type = 'code' AND lang = 'python'; + +-- Bucket headings by depth and summarize with string/numeric functions +SELECT + CASE WHEN depth <= 1 THEN 'top-level' ELSE 'nested' END AS bucket, + count(*), + group_concat(initcap(trim(content)), ', ') AS headings +FROM blocks +WHERE block_type = 'heading' +GROUP BY CASE WHEN depth <= 1 THEN 'top-level' ELSE 'nested' END; ``` ## Architecture @@ -475,14 +536,17 @@ flowchart LR #### Layer 1 — Zone Maps (document-level skip) -Built once per document and stored in the `.mq-db` file. Checked before any block is read: +Built once per document and stored in the `.mq-db` file. Checked before any block is read. + +**Via SQL** — `SqlEngine` derives a skip automatically from the WHERE clause, for a single, non-`JOIN`ed `SELECT ... FROM blocks`: + +| WHERE conjunct | Skips documents where… | +| ---------------------------------------- | --------------------------------------------------- | +| `lang = 'X'` | `code_languages` doesn't contain `X` | +| `depth = N` (`N > 0`) | `N` exceeds `max_heading_depth` | +| `block_type = 'heading' AND content = 'X'` | `heading_contents` has no case-insensitive match for `X` | -| Field | Skips documents where… | -| ------------------- | ------------------------------------ | -| `heading_contents` | The requested heading text is absent | -| `code_languages` | The requested language tag is absent | -| `max_heading_depth` | The requested depth cannot exist | -| `tags` | The tag filter cannot match | +**Via the Rust API** — `store.query().documents(|doc| ...)` lets you filter on *any* zone-map field yourself (`heading_slugs`, `frontmatter_keys`, `title`, `tags`, …), not just the patterns `SqlEngine` recognizes automatically. #### Layer 2 — Interval Index (section hierarchy) diff --git a/src/indexes.rs b/src/indexes.rs index a8a9e45..7f17a11 100644 --- a/src/indexes.rs +++ b/src/indexes.rs @@ -26,6 +26,10 @@ //! [`crate::document::ZoneMaps`] provides document-level skipping (skip entire //! files that cannot match). These indexes operate *within* a document once //! Zone Maps have decided it is worth scanning. +//! +//! `SqlEngine` applies this automatically (see `zone_map_skip` in +//! `src/sql.rs`) for `lang =` / `depth =` / heading `content =` conjuncts, +//! but only for a single, non-`JOIN`ed `FROM blocks`. use std::collections::{BTreeMap, HashMap}; diff --git a/src/sql.rs b/src/sql.rs index 97b45e5..d7e3b06 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -23,8 +23,11 @@ //! | `under(pre, post, anc_pre, anc_post)` | O(1) interval ancestor check | //! | `json_extract(json, path)` | Extract value from JSON string | //! | `mq(program, content)` | Run an mq program against Markdown content | -//! | `count(*)`/`min`/`max`/`sum`/`avg` | Aggregates | -//! | `lower`/`upper`/`length`/`coalesce` | Scalar utilities | +//! | `count`/`min`/`max`/`sum`/`avg`/`group_concat`/`string_agg` | Aggregates (`count` and `group_concat`/`string_agg` support `DISTINCT`) | +//! | `lower`/`upper`/`length`/`trim`/`ltrim`/`rtrim`/`concat`/`concat_ws`/`replace`/`left`/`right`/`lpad`/`rpad`/`reverse`/`repeat`/`initcap`/`ascii`/`chr`/`instr`/`split_part`/`substring`/`substr`/`position` | String functions | +//! | `abs`/`round`/`ceil`/`floor`/`trunc`/`mod`/`power`/`sqrt`/`sign`/`exp`/`ln`/`log`/`log10`/`log2`/`pi`/`greatest`/`least` | Numeric functions | +//! | `coalesce`/`ifnull`/`nullif` | Null handling | +//! | `typeof`/`now`/`current_timestamp`/`current_date`/`current_time`/`CASE WHEN` | Misc | //! //! # Example //! @@ -45,11 +48,11 @@ use std::collections::HashMap; use sqlparser::{ ast::{ - BinaryOperator, CreateTable, Expr, Function, FunctionArg, FunctionArgExpr, - FunctionArguments, GroupByExpr, Insert, JoinConstraint, JoinOperator, LimitClause, - ObjectName, ObjectNamePart, ObjectType, OrderByExpr, OrderByKind, Query, Select, - SelectItem, SetExpr, Statement, TableFactor, TableObject, UnaryOperator, Value as SqlValue, - Values, + BinaryOperator, CaseWhen, CeilFloorKind, CreateTable, DateTimeField, DuplicateTreatment, + Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments, GroupByExpr, Insert, + JoinConstraint, JoinOperator, LimitClause, ObjectName, ObjectNamePart, ObjectType, + OrderByExpr, OrderByKind, Query, Select, SelectItem, SetExpr, Statement, TableFactor, + TableObject, TrimWhereField, UnaryOperator, Value as SqlValue, Values, }, dialect::GenericDialect, parser::Parser, @@ -60,7 +63,7 @@ use mq_lang::{DefaultEngine, parse_markdown_input}; use crate::{ DocumentStore, MqdbError, block::{Block, BlockType, Properties, PropertyValue}, - document::Document, + document::{Document, ZoneMaps}, indexes::{DocumentIndex, IndexHint}, store::CustomTableState, }; @@ -605,11 +608,172 @@ fn eval_expr(expr: &Expr, row: &Row) -> Value { Expr::Function(f) => eval_function_call(f, row), Expr::Nested(inner) => eval_expr(inner, row), Expr::Cast { expr, .. } => eval_expr(expr, row), + Expr::Case { + operand, + conditions, + else_result, + .. + } => eval_case(operand.as_deref(), conditions, else_result.as_deref(), row), + Expr::Trim { + expr, + trim_where, + trim_what, + trim_characters, + } => eval_trim(expr, trim_where, trim_what, trim_characters, row), + Expr::Substring { + expr, + substring_from, + substring_for, + .. + } => eval_substring(expr, substring_from, substring_for, row), + Expr::Position { expr, r#in } => eval_position(expr, r#in, row), + Expr::Ceil { expr, field } => eval_ceil_floor(expr, field, row, true), + Expr::Floor { expr, field } => eval_ceil_floor(expr, field, row, false), // Subqueries are pre-resolved by resolve_subqueries before eval _ => Value::Null, } } +fn eval_case( + operand: Option<&Expr>, + conditions: &[CaseWhen], + else_result: Option<&Expr>, + row: &Row, +) -> Value { + let operand_val = operand.map(|o| eval_expr(o, row)); + for when in conditions { + let matched = match &operand_val { + Some(ov) => *ov == eval_expr(&when.condition, row), + None => eval_expr(&when.condition, row).is_truthy(), + }; + if matched { + return eval_expr(&when.result, row); + } + } + else_result + .map(|e| eval_expr(e, row)) + .unwrap_or(Value::Null) +} + +fn eval_trim( + expr: &Expr, + trim_where: &Option, + trim_what: &Option>, + trim_characters: &Option>, + row: &Row, +) -> Value { + let s = match eval_expr(expr, row).as_str() { + Some(s) => s.to_string(), + None => return Value::Null, + }; + let chars: Vec = if let Some(w) = trim_what { + eval_expr(w, row) + .as_str() + .map(|s| s.chars().collect()) + .unwrap_or_default() + } else if let Some(cs) = trim_characters { + cs.iter() + .filter_map(|e| eval_expr(e, row).as_str().map(|s| s.to_string())) + .collect::() + .chars() + .collect() + } else { + vec![' ', '\t', '\n', '\r'] + }; + let is_trim_char = |c: char| chars.contains(&c); + let trimmed = match trim_where { + Some(TrimWhereField::Leading) => s.trim_start_matches(is_trim_char).to_string(), + Some(TrimWhereField::Trailing) => s.trim_end_matches(is_trim_char).to_string(), + _ => s.trim_matches(is_trim_char).to_string(), + }; + Value::Str(trimmed) +} + +fn eval_substring( + expr: &Expr, + substring_from: &Option>, + substring_for: &Option>, + row: &Row, +) -> Value { + let s = match eval_expr(expr, row).as_str() { + Some(s) => s.to_string(), + None => return Value::Null, + }; + let chars: Vec = s.chars().collect(); + let len = chars.len() as i64; + let start_1based = substring_from + .as_ref() + .map(|e| eval_expr(e, row).as_i64().unwrap_or(1)) + .unwrap_or(1); + let take = substring_for + .as_ref() + .map(|e| eval_expr(e, row).as_i64().unwrap_or(len)); + // SQL substring is 1-based; positions before 1 are clamped, consuming from + // the requested length as if the string started earlier. + let start_0based = (start_1based - 1).max(0) as usize; + let end_0based = match take { + Some(n) => { + let end = start_1based - 1 + n.max(0); + end.clamp(0, len) as usize + } + None => len as usize, + }; + if start_0based >= chars.len() || end_0based <= start_0based { + return Value::Str(String::new()); + } + Value::Str(chars[start_0based..end_0based].iter().collect()) +} + +fn eval_position(expr: &Expr, r#in: &Expr, row: &Row) -> Value { + let needle = eval_expr(expr, row); + let haystack = eval_expr(r#in, row); + match (needle.as_str(), haystack.as_str()) { + (Some(needle), Some(haystack)) => { + let hay_chars: Vec = haystack.chars().collect(); + let needle_chars: Vec = needle.chars().collect(); + if needle_chars.is_empty() { + return Value::Int(0); + } + for i in 0..=hay_chars.len().saturating_sub(needle_chars.len()) { + if hay_chars[i..i + needle_chars.len()] == needle_chars[..] { + return Value::Int(i as i64 + 1); + } + } + Value::Int(0) + } + _ => Value::Null, + } +} + +fn eval_ceil_floor(expr: &Expr, field: &CeilFloorKind, row: &Row, is_ceil: bool) -> Value { + let n = match eval_expr(expr, row).as_f64() { + Some(n) => n, + None => return Value::Null, + }; + let scale = match field { + CeilFloorKind::Scale(v) => match &v.value { + SqlValue::Number(s, _) => s.parse::().unwrap_or(0), + _ => 0, + }, + CeilFloorKind::DateTimeField(DateTimeField::NoDateTime) => 0, + // Date-truncation forms (`CEIL(x TO DAY)`) need calendar data we don't track. + _ => return Value::Null, + }; + let factor = 10f64.powi(scale); + let scaled = n * factor; + let rounded = if is_ceil { + scaled.ceil() + } else { + scaled.floor() + }; + let result = rounded / factor; + if scale <= 0 && result.fract() == 0.0 { + Value::Int(result as i64) + } else { + Value::Float(result) + } +} + fn eval_binary(left: &Expr, op: &BinaryOperator, right: &Expr, row: &Row) -> Value { match op { BinaryOperator::And => { @@ -681,10 +845,7 @@ fn arith_op( fn eval_function_call(f: &Function, row: &Row) -> Value { let name = f.name.0.last().map(ident_value).unwrap_or(""); // Aggregates return placeholder; resolved later - if matches!( - name.to_lowercase().as_str(), - "count" | "sum" | "min" | "max" | "avg" - ) { + if is_aggregate_name(&name.to_lowercase()) { return Value::Int(1); } let args: Vec = match &f.args { @@ -720,44 +881,398 @@ fn eval_scalar_function(name: &str, args: &[Value]) -> Value { let key = path.trim_start_matches("$.").trim_matches('"'); extract_json_key(json, key) } - "lower" => args + "mq" => { + if args.len() < 2 { + return Value::Null; + } + let program = match args[0].as_str() { + Some(s) => s.to_string(), + None => return Value::Null, + }; + let content = match args[1].as_str() { + Some(s) => s.to_string(), + None => return Value::Null, + }; + eval_mq_scalar(&program, &content) + } + + // --- string functions --- + "lower" => str_fn(args, |s| s.to_lowercase()), + "upper" => str_fn(args, |s| s.to_uppercase()), + "length" | "len" | "char_length" | "character_length" => args .first() .and_then(|v| v.as_str()) - .map(|s| Value::Str(s.to_lowercase())) + .map(|s| Value::Int(s.chars().count() as i64)) .unwrap_or(Value::Null), - "upper" => args + "trim" => str_fn(args, |s| s.trim().to_string()), + "ltrim" => { + let chars = trim_char_set(args, 1); + str_fn(args, |s| { + s.trim_start_matches(|c| chars.contains(&c)).to_string() + }) + } + "rtrim" => { + let chars = trim_char_set(args, 1); + str_fn(args, |s| { + s.trim_end_matches(|c| chars.contains(&c)).to_string() + }) + } + "concat" => Value::Str( + args.iter() + .map(|v| v.display()) + .collect::>() + .join(""), + ), + "concat_ws" => { + let sep = match args.first().and_then(|v| v.as_str()) { + Some(s) => s, + None => return Value::Null, + }; + Value::Str( + args[1..] + .iter() + .filter(|v| !matches!(v, Value::Null)) + .map(|v| v.display()) + .collect::>() + .join(sep), + ) + } + "replace" => { + if args.len() < 3 { + return Value::Null; + } + match (args[0].as_str(), args[1].as_str(), args[2].as_str()) { + (Some(s), Some(from), Some(to)) => Value::Str(s.replace(from, to)), + _ => Value::Null, + } + } + "left" => str_int_fn(args, |chars, n| { + chars[..(n.max(0) as usize).min(chars.len())] + .iter() + .collect() + }), + "right" => str_int_fn(args, |chars, n| { + let n = (n.max(0) as usize).min(chars.len()); + chars[chars.len() - n..].iter().collect() + }), + "lpad" => pad_fn(args, true), + "rpad" => pad_fn(args, false), + "reverse" => str_fn(args, |s| s.chars().rev().collect()), + "repeat" => { + if args.len() < 2 { + return Value::Null; + } + match (args[0].as_str(), args[1].as_i64()) { + (Some(s), Some(n)) => Value::Str(s.repeat(n.max(0) as usize)), + _ => Value::Null, + } + } + "initcap" => str_fn(args, |s| { + s.split(' ') + .map(|word| { + let mut c = word.chars(); + match c.next() { + Some(first) => { + first.to_uppercase().collect::() + &c.as_str().to_lowercase() + } + None => String::new(), + } + }) + .collect::>() + .join(" ") + }), + "ascii" => args .first() .and_then(|v| v.as_str()) - .map(|s| Value::Str(s.to_uppercase())) + .and_then(|s| s.chars().next()) + .map(|c| Value::Int(c as i64)) .unwrap_or(Value::Null), - "length" | "len" => args + "chr" => args .first() - .and_then(|v| v.as_str()) - .map(|s| Value::Int(s.chars().count() as i64)) + .and_then(|v| v.as_i64()) + .and_then(|n| u32::try_from(n).ok()) + .and_then(char::from_u32) + .map(|c| Value::Str(c.to_string())) .unwrap_or(Value::Null), - "coalesce" => args - .iter() - .find(|v| !matches!(v, Value::Null)) - .cloned() - .unwrap_or(Value::Null), - "mq" => { + "instr" => { if args.len() < 2 { return Value::Null; } - let program = match args[0].as_str() { - Some(s) => s.to_string(), + match (args[0].as_str(), args[1].as_str()) { + (Some(haystack), Some(needle)) => { + let hay_chars: Vec = haystack.chars().collect(); + let needle_chars: Vec = needle.chars().collect(); + if needle_chars.is_empty() { + return Value::Int(0); + } + for i in 0..=hay_chars.len().saturating_sub(needle_chars.len()) { + if hay_chars[i..i + needle_chars.len()] == needle_chars[..] { + return Value::Int(i as i64 + 1); + } + } + Value::Int(0) + } + _ => Value::Null, + } + } + "split_part" => { + if args.len() < 3 { + return Value::Null; + } + match (args[0].as_str(), args[1].as_str(), args[2].as_i64()) { + (Some(s), Some(delim), Some(n)) if n > 0 => s + .split(delim) + .nth((n - 1) as usize) + .map(|p| Value::Str(p.to_string())) + .unwrap_or(Value::Null), + _ => Value::Null, + } + } + + // --- numeric functions --- + "abs" => num_fn(args, |n| n.abs(), |n| n.abs()), + "round" => { + let n = match args.first().and_then(|v| v.as_f64()) { + Some(n) => n, None => return Value::Null, }; - let content = match args[1].as_str() { - Some(s) => s.to_string(), + let scale = args.get(1).and_then(|v| v.as_i64()).unwrap_or(0); + let factor = 10f64.powi(scale as i32); + let result = (n * factor).round() / factor; + if scale <= 0 { + Value::Int(result as i64) + } else { + Value::Float(result) + } + } + "ceil" | "ceiling" => float_fn(args, |n| n.ceil()), + "floor" => float_fn(args, |n| n.floor()), + "trunc" | "truncate" => { + let n = match args.first().and_then(|v| v.as_f64()) { + Some(n) => n, None => return Value::Null, }; - eval_mq_scalar(&program, &content) + let scale = args.get(1).and_then(|v| v.as_i64()).unwrap_or(0); + let factor = 10f64.powi(scale as i32); + let result = (n * factor).trunc() / factor; + if scale <= 0 { + Value::Int(result as i64) + } else { + Value::Float(result) + } + } + "mod" => { + if args.len() < 2 { + return Value::Null; + } + match (&args[0], &args[1]) { + (Value::Int(a), Value::Int(b)) if *b != 0 => Value::Int(a % b), + _ => match (args[0].as_f64(), args[1].as_f64()) { + (Some(a), Some(b)) if b != 0.0 => Value::Float(a % b), + _ => Value::Null, + }, + } + } + "power" | "pow" => { + if args.len() < 2 { + return Value::Null; + } + match (args[0].as_f64(), args[1].as_f64()) { + (Some(a), Some(b)) => Value::Float(a.powf(b)), + _ => Value::Null, + } + } + "sqrt" => float_fn(args, |n| n.sqrt()), + "sign" => float_fn(args, |n| { + if n > 0.0 { + 1.0 + } else if n < 0.0 { + -1.0 + } else { + 0.0 + } + }), + "exp" => float_fn(args, |n| n.exp()), + "ln" => float_fn(args, |n| n.ln()), + "log10" => float_fn(args, |n| n.log10()), + "log2" => float_fn(args, |n| n.log2()), + "log" => { + let n = match args.first().and_then(|v| v.as_f64()) { + Some(n) => n, + None => return Value::Null, + }; + match args.get(1).and_then(|v| v.as_f64()) { + Some(base) => Value::Float(n.log(base)), + None => Value::Float(n.log10()), + } + } + "pi" => Value::Float(std::f64::consts::PI), + "greatest" => args + .iter() + .filter(|v| !matches!(v, Value::Null)) + .cloned() + .max_by(|a, b| a.cmp_val(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(Value::Null), + "least" => args + .iter() + .filter(|v| !matches!(v, Value::Null)) + .cloned() + .min_by(|a, b| a.cmp_val(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap_or(Value::Null), + + // --- null handling --- + "coalesce" | "ifnull" => args + .iter() + .find(|v| !matches!(v, Value::Null)) + .cloned() + .unwrap_or(Value::Null), + "nullif" => { + if args.len() < 2 { + return Value::Null; + } + if args[0] == args[1] { + Value::Null + } else { + args[0].clone() + } + } + + // --- misc --- + "typeof" => Value::Str( + match args.first() { + Some(Value::Str(_)) => "text", + Some(Value::Int(_)) => "integer", + Some(Value::Float(_)) => "float", + Some(Value::Bool(_)) => "boolean", + Some(Value::Null) | None => "null", + } + .to_string(), + ), + "now" | "current_timestamp" => Value::Str(current_datetime_utc(true, true)), + "current_date" => Value::Str(current_datetime_utc(true, false)), + "current_time" => Value::Str(current_datetime_utc(false, true)), + _ => Value::Null, + } +} + +fn str_fn(args: &[Value], f: impl Fn(&str) -> String) -> Value { + args.first() + .and_then(|v| v.as_str()) + .map(|s| Value::Str(f(s))) + .unwrap_or(Value::Null) +} + +fn str_int_fn(args: &[Value], f: impl Fn(&[char], i64) -> String) -> Value { + if args.len() < 2 { + return Value::Null; + } + match (args[0].as_str(), args[1].as_i64()) { + (Some(s), Some(n)) => { + let chars: Vec = s.chars().collect(); + Value::Str(f(&chars, n)) } _ => Value::Null, } } +fn num_fn(args: &[Value], int_f: impl Fn(i64) -> i64, flt_f: impl Fn(f64) -> f64) -> Value { + match args.first() { + Some(Value::Int(n)) => Value::Int(int_f(*n)), + Some(v) => v + .as_f64() + .map(|n| Value::Float(flt_f(n))) + .unwrap_or(Value::Null), + None => Value::Null, + } +} + +fn float_fn(args: &[Value], f: impl Fn(f64) -> f64) -> Value { + args.first() + .and_then(|v| v.as_f64()) + .map(|n| Value::Float(f(n))) + .unwrap_or(Value::Null) +} + +/// Builds the set of characters TRIM/LTRIM/RTRIM should strip, defaulting to +/// whitespace when no explicit character argument is given. +fn trim_char_set(args: &[Value], chars_idx: usize) -> Vec { + args.get(chars_idx) + .and_then(|v| v.as_str()) + .map(|s| s.chars().collect()) + .unwrap_or_else(|| vec![' ', '\t', '\n', '\r']) +} + +fn pad_fn(args: &[Value], left: bool) -> Value { + if args.len() < 2 { + return Value::Null; + } + let s = match args[0].as_str() { + Some(s) => s, + None => return Value::Null, + }; + let target_len = match args[1].as_i64() { + Some(n) => n.max(0) as usize, + None => return Value::Null, + }; + let pad_str = args.get(2).and_then(|v| v.as_str()).unwrap_or(" "); + let mut chars: Vec = s.chars().collect(); + if chars.len() >= target_len { + chars.truncate(target_len); + return Value::Str(chars.into_iter().collect()); + } + if pad_str.is_empty() { + return Value::Str(s.to_string()); + } + let pad_chars: Vec = pad_str.chars().collect(); + let needed = target_len - chars.len(); + let padding: Vec = pad_chars.iter().cycle().take(needed).copied().collect(); + if left { + Value::Str(padding.into_iter().chain(chars).collect()) + } else { + chars.extend(padding); + Value::Str(chars.into_iter().collect()) + } +} + +/// Returns the current UTC time formatted for `now()`/`current_timestamp`/ +/// `current_date`/`current_time`. No date columns exist in the schema, so +/// this only needs to support clock-style scalar lookups, not arithmetic. +fn current_datetime_utc(with_date: bool, with_time: bool) -> String { + let secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + let days = (secs / 86400) as i64; + let time_of_day = secs % 86400; + let (y, m, d) = civil_from_days(days); + let (h, mi, s) = ( + time_of_day / 3600, + (time_of_day / 60) % 60, + time_of_day % 60, + ); + match (with_date, with_time) { + (true, true) => format!("{y:04}-{m:02}-{d:02} {h:02}:{mi:02}:{s:02}"), + (true, false) => format!("{y:04}-{m:02}-{d:02}"), + _ => format!("{h:02}:{mi:02}:{s:02}"), + } +} + +/// Howard Hinnant's `civil_from_days` algorithm: converts a day count +/// since the Unix epoch (1970-01-01) into a proleptic-Gregorian (year, month, day). +fn civil_from_days(z: i64) -> (i64, u32, u32) { + let z = z + 719468; + let era = if z >= 0 { z } else { z - 146096 } / 146097; + let doe = (z - era * 146097) as u64; + let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + let y = yoe as i64 + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let d = (doy - (153 * mp + 2) / 5 + 1) as u32; + let m = if mp < 10 { mp + 3 } else { mp - 9 } as u32; + let y = if m <= 2 { y + 1 } else { y }; + (y, m, d) +} + fn eval_mq_scalar(program: &str, content: &str) -> Value { let mut engine = DefaultEngine::default(); engine.load_builtin_module(); @@ -1196,12 +1711,15 @@ impl<'a> SqlEngine<'a> { }; // 1. Materialise FROM — with index-based predicate pushdown - let hint = select - .selection - .as_ref() + let where_expr = select.selection.as_ref(); + let hint = where_expr .map(analyze_where_for_index) .unwrap_or(IndexHint::FullScan); - let mut rows = self.materialise_from_with_hint(&select.from, &hint)?; + // Unlike `hint`, a skip has no later row-by-row recheck, so only + // allow it for a single un-joined FROM table (no alias ambiguity). + let zone_filter = + where_expr.filter(|_| select.from.len() == 1 && select.from[0].joins.is_empty()); + let mut rows = self.materialise_from_with_hint(&select.from, &hint, zone_filter)?; // 2. WHERE (full predicate evaluation; index only pre-filtered) if let Some(where_expr) = &select.selection { @@ -1277,6 +1795,7 @@ impl<'a> SqlEngine<'a> { &self, from: &[sqlparser::ast::TableWithJoins], hint: &IndexHint, + zone_filter: Option<&Expr>, ) -> Result, MqdbError> { if from.is_empty() { return Ok(vec![Row { @@ -1284,10 +1803,10 @@ impl<'a> SqlEngine<'a> { values: vec![], }]); } - let mut rows = self.table_rows_with_hint(&from[0].relation, hint)?; + let mut rows = self.table_rows_with_hint(&from[0].relation, hint, zone_filter)?; for join in &from[0].joins { // Joined tables always full-scan (join partner) - let right = self.table_rows_with_hint(&join.relation, &IndexHint::FullScan)?; + let right = self.table_rows_with_hint(&join.relation, &IndexHint::FullScan, None)?; rows = cross_join(rows, right); match &join.join_operator { JoinOperator::Inner(JoinConstraint::On(on)) @@ -1301,10 +1820,11 @@ impl<'a> SqlEngine<'a> { } } for twj in from.iter().skip(1) { - let right = self.table_rows_with_hint(&twj.relation, &IndexHint::FullScan)?; + let right = self.table_rows_with_hint(&twj.relation, &IndexHint::FullScan, None)?; rows = cross_join(rows, right); for join in &twj.joins { - let right2 = self.table_rows_with_hint(&join.relation, &IndexHint::FullScan)?; + let right2 = + self.table_rows_with_hint(&join.relation, &IndexHint::FullScan, None)?; rows = cross_join(rows, right2); } } @@ -1315,6 +1835,7 @@ impl<'a> SqlEngine<'a> { &self, factor: &TableFactor, hint: &IndexHint, + zone_filter: Option<&Expr>, ) -> Result, MqdbError> { let (table_name, alias) = match factor { TableFactor::Table { name, alias, .. } => { @@ -1332,6 +1853,14 @@ impl<'a> SqlEngine<'a> { let mut global_idx: u32 = 0; for (doc, doc_idx) in self.documents_with_indexes() { + // Zone-map document skip: prove no block in this document + // can match before reading any of them. + if let Some(we) = zone_filter + && zone_map_skip(&doc.zone_maps, we) + { + global_idx += doc.blocks.len() as u32; + continue; + } // Try index-based access first if let Some(local_indices) = hint.resolve(doc_idx) { // Only materialise the pre-filtered blocks @@ -1585,10 +2114,17 @@ fn has_aggregate(projection: &[SelectItem]) -> bool { fn is_agg_expr(expr: &Expr) -> bool { matches!(expr, Expr::Function(f) if { let name = f.name.0.last().map(ident_value).unwrap_or("").to_lowercase(); - matches!(name.as_str(), "count" | "sum" | "min" | "max" | "avg") + is_aggregate_name(&name) }) } +fn is_aggregate_name(name: &str) -> bool { + matches!( + name, + "count" | "sum" | "min" | "max" | "avg" | "group_concat" | "string_agg" + ) +} + fn eval_agg_row( projection: &[SelectItem], group_by_exprs: &[Expr], @@ -1612,7 +2148,27 @@ fn eval_agg_row( .unwrap_or("") .to_lowercase(); match name.as_str() { + "count" if is_distinct(f) => { + let mut seen: Vec = Vec::new(); + for r in group_rows { + let v = agg_arg(f, r); + if !matches!(v, Value::Null) && !seen.contains(&v) { + seen.push(v); + } + } + seen.len().to_string() + } "count" => group_rows.len().to_string(), + "group_concat" | "string_agg" => { + let sep = agg_separator(f); + group_rows + .iter() + .map(|r| agg_arg(f, r)) + .filter(|v| !matches!(v, Value::Null)) + .map(|v| v.display()) + .collect::>() + .join(&sep) + } "sum" => { let sum: f64 = group_rows .iter() @@ -1676,6 +2232,26 @@ fn agg_arg(f: &Function, row: &Row) -> Value { .unwrap_or(Value::Null) } +fn is_distinct(f: &Function) -> bool { + matches!( + &f.args, + FunctionArguments::List(al) if al.duplicate_treatment == Some(DuplicateTreatment::Distinct) + ) +} + +/// Separator for `group_concat(expr[, sep])` / `string_agg(expr, sep)`; the +/// second argument is expected to be a literal, so it's read straight off +/// the AST rather than through `eval_expr` (which needs a row). +fn agg_separator(f: &Function) -> String { + if let FunctionArguments::List(al) = &f.args + && let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(v)))) = al.args.get(1) + && let Value::Str(s) = eval_sql_value(&v.value) + { + return s; + } + ",".to_string() +} + fn expr_structurally_eq(a: &Expr, b: &Expr) -> bool { format!("{:?}", a) == format!("{:?}", b) } @@ -1717,6 +2293,104 @@ fn apply_limit(mut rows: Vec>, limit: Option<&Expr>) -> Vec Vec<&Expr> { + match expr { + Expr::BinaryOp { + left, + op: BinaryOperator::And, + right, + } => { + let mut out = flatten_and_conjuncts(left); + out.extend(flatten_and_conjuncts(right)); + out + } + Expr::Nested(inner) => flatten_and_conjuncts(inner), + other => vec![other], + } +} + +/// Decides whether a whole document can be skipped using [`ZoneMaps`], +/// without reading any of its blocks. Unlike [`IndexHint`], a wrong skip +/// here silently drops matching rows, so this only returns `true` when it +/// can prove no block in the document satisfies `where_expr`. +fn zone_map_skip(zone_maps: &ZoneMaps, where_expr: &Expr) -> bool { + let mut eq_block_type: Option = None; + let mut eq_content: Option = None; + let mut eq_lang: Option = None; + let mut eq_depth: Option = None; + + for conjunct in flatten_and_conjuncts(where_expr) { + let Expr::BinaryOp { + left, + op: BinaryOperator::Eq, + right, + } = conjunct + else { + continue; + }; + let col = expr_col_name(left).or_else(|| expr_col_name(right)); + let val = expr_str_val(right).or_else(|| expr_str_val(left)); + let int_val = expr_int_val(right).or_else(|| expr_int_val(left)); + + match col.as_deref() { + Some("block_type") => { + if let Some(s) = val.as_deref() + && let Some(bt) = BlockType::from_str(s) + { + eq_block_type = Some(bt); + } + } + Some("content") => eq_content = val, + // lang = '' means "no lang" (matches non-code blocks), which + // code_languages says nothing about. + Some("lang") => { + if let Some(s) = val + && !s.is_empty() + { + eq_lang = Some(s); + } + } + // depth = 0 means "no heading depth" (matches non-heading + // blocks), which max_heading_depth says nothing about. + Some("depth") => { + if let Some(n) = int_val + && let Ok(n) = u8::try_from(n) + && n > 0 + { + eq_depth = Some(n); + } + } + _ => {} + } + } + + if let Some(lang) = &eq_lang + && !zone_maps.code_languages.contains(lang) + { + return true; + } + if let Some(depth) = eq_depth + && depth > zone_maps.max_heading_depth + { + return true; + } + // Only safe when `block_type = 'heading'` is also required — `content` + // alone could match a non-heading block. + if let Some(content) = &eq_content + && eq_block_type == Some(BlockType::Heading) + && !zone_maps + .heading_contents + .iter() + .any(|h| h.eq_ignore_ascii_case(content)) + { + return true; + } + + false +} + /// Inspect the WHERE expression and return the best [`IndexHint`]. /// /// Only analyses the *outermost* conjunct that can be served by an index. @@ -1920,6 +2594,16 @@ mod tests { s } + // Doc B (no code, depth 1) sits between two rust/depth-3 docs. + fn make_multi_doc_store() -> DocumentStore { + let mut s = DocumentStore::new(); + s.add_str("# A\n\n```rust\nfn a(){}\n```\n").unwrap(); + s.add_str("# B\n\nParagraph\n").unwrap(); + s.add_str("# C\n\n## C2\n\n### C3\n\n```rust\nfn c(){}\n```\n") + .unwrap(); + s + } + #[test] fn test_sql_select_all_blocks() { let store = make_store(); @@ -2275,4 +2959,188 @@ mod tests { assert_eq!(out.rows.len(), 1); assert_eq!(out.rows[0][0], "NULL"); } + + fn eval_one(sql: &str) -> String { + let store = DocumentStore::new(); + let engine = SqlEngine::new(&store).unwrap(); + engine.execute(sql).unwrap().rows[0][0].clone() + } + + #[rstest] + // string functions + #[case("SELECT lower('Hello')", "hello")] + #[case("SELECT upper('Hello')", "HELLO")] + #[case("SELECT length('héllo')", "5")] + #[case("SELECT trim(' hi ')", "hi")] + #[case("SELECT ltrim(' hi ')", "hi ")] + #[case("SELECT rtrim(' hi ')", " hi")] + #[case("SELECT trim(LEADING 'x' FROM 'xxhixx')", "hixx")] + #[case("SELECT trim(TRAILING 'x' FROM 'xxhixx')", "xxhi")] + #[case("SELECT trim('x' FROM 'xxhixx')", "hi")] + #[case("SELECT concat('a', 'b', 'c')", "abc")] + #[case("SELECT concat_ws('-', 'a', 'b', NULL, 'c')", "a-b-c")] + #[case("SELECT replace('foobar', 'o', '0')", "f00bar")] + #[case("SELECT left('hello', 3)", "hel")] + #[case("SELECT right('hello', 3)", "llo")] + #[case("SELECT lpad('7', 3, '0')", "007")] + #[case("SELECT rpad('7', 3, '0')", "700")] + #[case("SELECT reverse('hello')", "olleh")] + #[case("SELECT repeat('ab', 3)", "ababab")] + #[case("SELECT initcap('hello world')", "Hello World")] + #[case("SELECT ascii('A')", "65")] + #[case("SELECT chr(65)", "A")] + #[case("SELECT instr('hello world', 'world')", "7")] + #[case("SELECT position('world' in 'hello world')", "7")] + #[case("SELECT split_part('a,b,c', ',', 2)", "b")] + #[case("SELECT substring('hello world', 1, 5)", "hello")] + #[case("SELECT substring('hello world' from 7)", "world")] + #[case("SELECT substr('hello world', 7, 5)", "world")] + // numeric functions + #[case("SELECT abs(-5)", "5")] + #[case("SELECT abs(-5.5)", "5.5")] + #[case("SELECT round(3.456, 2)", "3.46")] + #[case("SELECT round(3.5)", "4")] + #[case("SELECT ceil(3.1)", "4")] + #[case("SELECT floor(3.9)", "3")] + #[case("SELECT trunc(3.789, 1)", "3.7")] + #[case("SELECT mod(10, 3)", "1")] + #[case("SELECT power(2, 10)", "1024")] + #[case("SELECT sqrt(16)", "4")] + #[case("SELECT sign(-3)", "-1")] + #[case("SELECT greatest(3, 7, 2)", "7")] + #[case("SELECT least(3, 7, 2)", "2")] + // null handling + #[case("SELECT coalesce(NULL, NULL, 'x')", "x")] + #[case("SELECT ifnull(NULL, 'y')", "y")] + #[case("SELECT nullif('a', 'a')", "NULL")] + #[case("SELECT nullif('a', 'b')", "a")] + // misc + #[case("SELECT typeof('x')", "text")] + #[case("SELECT typeof(1)", "integer")] + // CASE + #[case( + "SELECT CASE WHEN 1 = 2 THEN 'a' WHEN 1 = 1 THEN 'b' ELSE 'c' END", + "b" + )] + #[case("SELECT CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'b' ELSE 'c' END", "b")] + #[case("SELECT CASE WHEN 1 = 2 THEN 'a' ELSE 'c' END", "c")] + fn test_sql_scalar_functions(#[case] sql: &str, #[case] expected: &str) { + assert_eq!(eval_one(sql), expected); + } + + #[test] + fn test_sql_group_concat() { + let store = make_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT group_concat(content) FROM blocks WHERE block_type = 'heading'") + .unwrap(); + assert_eq!(out.rows[0][0], "Doc,Architecture,Other"); + } + + #[test] + fn test_sql_string_agg_custom_separator() { + let store = make_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT string_agg(content, ' | ') FROM blocks WHERE block_type = 'heading'") + .unwrap(); + assert_eq!(out.rows[0][0], "Doc | Architecture | Other"); + } + + #[test] + fn test_sql_count_distinct() { + let store = make_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT count(DISTINCT block_type) FROM blocks") + .unwrap(); + assert_eq!(out.rows[0][0], "3"); + } + + // doc B has no code at all; A and C's rust blocks must still come through. + #[test] + fn test_sql_zone_map_skip_by_lang() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT content FROM blocks WHERE lang = 'rust' ORDER BY content") + .unwrap(); + let contents: Vec<&str> = out.rows.iter().map(|r| r[0].as_str()).collect(); + assert_eq!(contents, vec!["fn a(){}", "fn c(){}"]); + } + + // depth=3 only exists in doc C; A and B (max depth 1) must be skipped. + #[test] + fn test_sql_zone_map_skip_by_depth() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT content FROM blocks WHERE depth = 3") + .unwrap(); + assert_eq!(out.rows.len(), 1); + assert_eq!(out.rows[0][0], "C3"); + } + + // Only doc B has a heading named "B"; requires block_type='heading' too. + #[test] + fn test_sql_zone_map_skip_by_heading_content() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT content FROM blocks WHERE block_type = 'heading' AND content = 'B'") + .unwrap(); + assert_eq!(out.rows.len(), 1); + assert_eq!(out.rows[0][0], "B"); + } + + // `lang = ''` means "no lang"; must never trigger a code-language skip. + #[test] + fn test_sql_zone_map_no_skip_on_empty_lang() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute("SELECT content FROM blocks WHERE lang = ''") + .unwrap(); + let contents: Vec<&str> = out.rows.iter().map(|r| r[0].as_str()).collect(); + assert!(contents.contains(&"B"), "doc B must not be skipped"); + assert!(contents.contains(&"Paragraph")); + } + + // `id` must stay stable regardless of which documents get skipped. + #[test] + fn test_sql_zone_map_skip_preserves_block_ids() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let full = engine.execute("SELECT id, content FROM blocks").unwrap(); + let filtered = engine + .execute("SELECT id, content FROM blocks WHERE lang = 'rust'") + .unwrap(); + assert_eq!(filtered.rows.len(), 2); + for row in &filtered.rows { + let same_id = full.rows.iter().find(|r| r[0] == row[0]).unwrap(); + assert_eq!( + same_id[1], row[1], + "id {} must reference the same block content in both queries", + row[0] + ); + } + } + + // Zone-map skip is disabled whenever FROM has a join (see `exec_query`). + // Just checks a join with a recognized conjunct still scans normally. + #[test] + fn test_sql_zone_map_skip_disabled_for_joins() { + let store = make_multi_doc_store(); + let engine = SqlEngine::new(&store).unwrap(); + let out = engine + .execute( + "SELECT h.content, c.content FROM blocks h + JOIN blocks c ON c.document_id = h.document_id AND c.block_type = 'code' + WHERE h.block_type = 'heading'", + ) + .unwrap(); + let headings: Vec<&str> = out.rows.iter().map(|r| r[0].as_str()).collect(); + assert_eq!(headings, vec!["A", "C", "C2", "C3"]); + } }