diff --git a/Cargo.lock b/Cargo.lock
index 69390359..c8646ce6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3994,6 +3994,7 @@ dependencies = [
  "aes-gcm",
  "crc32c",
  "fluxbench",
+ "getrandom 0.3.4",
  "io-uring",
  "libc",
  "memmap2",
diff --git a/nodedb-columnar/src/mutation.rs b/nodedb-columnar/src/mutation.rs
index 4f08d904..15f90b9e 100644
--- a/nodedb-columnar/src/mutation.rs
+++ b/nodedb-columnar/src/mutation.rs
@@ -85,7 +85,7 @@ impl MutationEngine {
         }
 
         // Generate WAL record BEFORE applying the mutation.
-        let row_data = encode_row_for_wal(values);
+        let row_data = encode_row_for_wal(values)?;
         let wal = ColumnarWalRecord::InsertRow {
             collection: self.collection.clone(),
             row_data,
@@ -310,6 +310,13 @@ impl MutationEngine {
         self.memtable.get_row(row_idx)
     }
 
+    /// The segment ID that will be assigned to the next flushed segment.
+    ///
+    /// Use this to obtain the ID to pass to `on_memtable_flushed`.
+    pub fn next_segment_id(&self) -> u32 {
+        self.next_segment_id
+    }
+
     /// Whether a segment should be compacted based on its delete ratio.
     pub fn should_compact(&self, segment_id: u32, total_rows: u64) -> bool {
         self.delete_bitmaps
diff --git a/nodedb-columnar/src/wal_record.rs b/nodedb-columnar/src/wal_record.rs
index 0a00006b..a8462242 100644
--- a/nodedb-columnar/src/wal_record.rs
+++ b/nodedb-columnar/src/wal_record.rs
@@ -91,7 +91,9 @@ impl ColumnarWalRecord {
 /// Each value is written as: [type_tag: u8][value_bytes].
 /// This is more compact than MessagePack for typed columns and enables
 /// direct replay into the memtable without schema interpretation overhead.
-pub fn encode_row_for_wal(values: &[nodedb_types::value::Value]) -> Vec<u8> {
+pub fn encode_row_for_wal(
+    values: &[nodedb_types::value::Value],
+) -> Result<Vec<u8>, crate::error::ColumnarError> {
     use nodedb_types::value::Value;
 
     let mut buf = Vec::with_capacity(values.len() * 10); // Rough estimate.
@@ -152,14 +154,63 @@ pub fn encode_row_for_wal(values: &[nodedb_types::value::Value]) -> Vec<u8> {
             _ => {
                 // Geometry and other complex types: serialize as JSON bytes.
                 buf.push(10);
-                let json = sonic_rs::to_vec(value).unwrap_or_default();
+                let json = sonic_rs::to_vec(value).map_err(|e| {
+                    crate::error::ColumnarError::Serialization(format!(
+                        "failed to serialize value as JSON: {e}"
+                    ))
+                })?;
                 buf.extend_from_slice(&(json.len() as u32).to_le_bytes());
                 buf.extend_from_slice(&json);
             }
         }
     }
 
-    buf
+    Ok(buf)
+}
+
+/// Maximum length for a variable-length field in a WAL record (256 MiB).
+/// Prevents OOM from crafted/corrupt records with bogus length prefixes.
+const MAX_FIELD_LEN: usize = 256 * 1024 * 1024;
+
+/// Read exactly `n` bytes from `data` at `cursor`, advancing cursor.
+/// Returns `Err` if not enough bytes remain.
+fn read_slice<'a>(
+    data: &'a [u8],
+    cursor: &mut usize,
+    n: usize,
+    context: &str,
+) -> Result<&'a [u8], crate::error::ColumnarError> {
+    let end = cursor.checked_add(n).ok_or_else(|| {
+        crate::error::ColumnarError::Serialization(format!("overflow in {context}"))
+    })?;
+    if end > data.len() {
+        return Err(crate::error::ColumnarError::Serialization(format!(
+            "truncated {context}: need {n} bytes at offset {cursor}, have {}",
+            data.len().saturating_sub(*cursor)
+        )));
+    }
+    let slice = &data[*cursor..end];
+    *cursor = end;
+    Ok(slice)
+}
+
+/// Read a u32 length prefix, validate it against MAX_FIELD_LEN, then read
+/// that many bytes. Returns the payload slice.
+fn read_length_prefixed<'a>( + data: &'a [u8], + cursor: &mut usize, + context: &str, +) -> Result<&'a [u8], crate::error::ColumnarError> { + let len_bytes = read_slice(data, cursor, 4, context)?; + let len = u32::from_le_bytes(len_bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization(format!("truncated {context} len")) + })?) as usize; + if len > MAX_FIELD_LEN { + return Err(crate::error::ColumnarError::Serialization(format!( + "{context} length {len} exceeds maximum {MAX_FIELD_LEN}" + ))); + } + read_slice(data, cursor, len, context) } /// Decode a row from the columnar wire format back into Values. @@ -172,37 +223,40 @@ pub fn decode_row_from_wal( let mut cursor = 0; while cursor < data.len() { - let tag = data[cursor]; - cursor += 1; + let tag_slice = read_slice(data, &mut cursor, 1, "tag")?; + let tag = tag_slice[0]; let value = match tag { 0 => Value::Null, 1 => { - let v = i64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { + let bytes = read_slice(data, &mut cursor, 8, "i64")?; + let v = i64::from_le_bytes(bytes.try_into().map_err(|_| { crate::error::ColumnarError::Serialization("truncated i64".into()) })?); - cursor += 8; Value::Integer(v) } 2 => { - let v = f64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { + let bytes = read_slice(data, &mut cursor, 8, "f64")?; + let v = f64::from_le_bytes(bytes.try_into().map_err(|_| { crate::error::ColumnarError::Serialization("truncated f64".into()) })?); - cursor += 8; Value::Float(v) } 3 => { - let v = data[cursor] != 0; - cursor += 1; - Value::Bool(v) + let bytes = read_slice(data, &mut cursor, 1, "bool")?; + Value::Bool(bytes[0] != 0) } 4 | 5 | 8 => { - let len = u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated len".into()) - })?) as usize; - cursor += 4; - let bytes = &data[cursor..cursor + len]; - cursor += len; + let bytes = read_length_prefixed( + data, + &mut cursor, + match tag { + 4 => "string", + 5 => "bytes", + 8 => "uuid", + _ => unreachable!(), + }, + )?; match tag { 4 => Value::String(String::from_utf8_lossy(bytes).into_owned()), 5 => Value::Bytes(bytes.to_vec()), @@ -211,43 +265,41 @@ pub fn decode_row_from_wal( } } 6 => { - let micros = - i64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated timestamp".into()) - })?); - cursor += 8; + let bytes = read_slice(data, &mut cursor, 8, "timestamp")?; + let micros = i64::from_le_bytes(bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated timestamp".into()) + })?); Value::DateTime(nodedb_types::datetime::NdbDateTime::from_micros(micros)) } 7 => { - let mut bytes = [0u8; 16]; - bytes.copy_from_slice(&data[cursor..cursor + 16]); - cursor += 16; - Value::Decimal(rust_decimal::Decimal::deserialize(bytes)) + let bytes = read_slice(data, &mut cursor, 16, "decimal")?; + let mut arr = [0u8; 16]; + arr.copy_from_slice(bytes); + Value::Decimal(rust_decimal::Decimal::deserialize(arr)) } 9 => { - let count = - u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated vector count".into()) - })?) as usize; - cursor += 4; + let count_bytes = read_slice(data, &mut cursor, 4, "vector count")?; + let count = u32::from_le_bytes(count_bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated vector count".into()) + })?) 
as usize; + if count > MAX_FIELD_LEN / 4 { + return Err(crate::error::ColumnarError::Serialization(format!( + "vector count {count} exceeds maximum {}", + MAX_FIELD_LEN / 4 + ))); + } let mut arr = Vec::with_capacity(count); for _ in 0..count { - let f = - f32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated f32".into()) - })?); - cursor += 4; + let fb = read_slice(data, &mut cursor, 4, "vector f32")?; + let f = f32::from_le_bytes(fb.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated f32".into()) + })?); arr.push(Value::Float(f as f64)); } Value::Array(arr) } 10 => { - let len = u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated json len".into()) - })?) as usize; - cursor += 4; - let json_bytes = &data[cursor..cursor + len]; - cursor += len; + let json_bytes = read_length_prefixed(data, &mut cursor, "json")?; sonic_rs::from_slice(json_bytes).unwrap_or(Value::Null) } _ => { @@ -316,7 +368,7 @@ mod tests { Value::Array(vec![Value::Float(1.0), Value::Float(2.0)]), ]; - let encoded = encode_row_for_wal(&values); + let encoded = encode_row_for_wal(&values).expect("encode"); let decoded = decode_row_from_wal(&encoded).expect("decode"); assert_eq!(decoded.len(), values.len()); @@ -335,4 +387,70 @@ mod tests { ); assert_eq!(decoded[8], Value::Null); } + + #[test] + fn decode_truncated_i64_returns_error() { + // Tag 1 (i64) requires 8 payload bytes; supply none. + // Today the slice index `data[cursor..cursor+8]` panics with an index + // out-of-bounds. After the fix, `try_into()` returns the + // Serialization error instead. + let result = decode_row_from_wal(&[1]); + assert!( + result.is_err(), + "truncated i64 payload must return Err, not panic" + ); + } + + #[test] + fn decode_truncated_string_returns_error() { + // Tag 4 (string): length prefix says 255 bytes but the slice ends + // immediately after the 4-byte length field. The read of + // `data[cursor..cursor+255]` panics today; after the fix it errors. + let input = { + let mut v = vec![4u8]; // tag = string + v.extend_from_slice(&255u32.to_le_bytes()); // len = 255 + // no payload bytes follow + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "truncated string payload must return Err, not panic" + ); + } + + #[test] + fn decode_huge_vector_count_returns_error() { + // Tag 9 (vector array): count = 0x7FFFFFFF. After reading the count, + // the very first iteration tries to read 4 bytes of f32 from an empty + // slice, which panics today. After the fix the loop errors out cleanly + // before any allocation proportional to count is attempted. + let input = { + let mut v = vec![9u8]; // tag = vector array + v.extend_from_slice(&0x7FFF_FFFFu32.to_le_bytes()); // count + // no f32 bytes follow + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "huge vector count with no payload must return Err, not panic or OOM" + ); + } + + #[test] + fn decode_truncated_decimal_returns_error() { + // Tag 7 (Decimal) requires 16 bytes; supply only 4. + // `data[cursor..cursor+16]` panics today; after the fix it errors. 
+ let input = { + let mut v = vec![7u8]; // tag = decimal + v.extend_from_slice(&[0u8; 4]); // only 4 bytes, need 16 + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "truncated decimal payload must return Err, not panic" + ); + } } diff --git a/nodedb-query/src/expr_parse.rs b/nodedb-query/src/expr_parse.rs index 4ccd9058..4c4f4762 100644 --- a/nodedb-query/src/expr_parse.rs +++ b/nodedb-query/src/expr_parse.rs @@ -26,7 +26,7 @@ use nodedb_types::Value; pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec), String> { let tokens = tokenize(text)?; let mut pos = 0; - let expr = parse_expr(&tokens, &mut pos)?; + let expr = parse_expr(&tokens, &mut pos, &mut 0)?; if pos < tokens.len() { return Err(format!( "unexpected token after expression: '{}'", @@ -195,16 +195,20 @@ fn tokenize(input: &str) -> Result, String> { // ── Recursive descent parser ────────────────────────────────────────── +/// Maximum recursion depth for nested parentheses / sub-expressions. +/// Exceeding this limit returns `Err` instead of overflowing the stack. +const MAX_EXPR_DEPTH: usize = 128; + /// Parse an expression (lowest precedence: OR). -fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result { - parse_or(tokens, pos) +fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + parse_or(tokens, pos, depth) } -fn parse_or(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_and(tokens, pos)?; +fn parse_or(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_and(tokens, pos, depth)?; while peek_keyword(tokens, *pos, "OR") { *pos += 1; - let right = parse_and(tokens, pos)?; + let right = parse_and(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op: BinaryOp::Or, @@ -214,11 +218,11 @@ fn parse_or(tokens: &[Token], pos: &mut usize) -> Result { Ok(left) } -fn parse_and(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_comparison(tokens, pos)?; +fn parse_and(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_comparison(tokens, pos, depth)?; while peek_keyword(tokens, *pos, "AND") { *pos += 1; - let right = parse_comparison(tokens, pos)?; + let right = parse_comparison(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op: BinaryOp::And, @@ -228,8 +232,12 @@ fn parse_and(tokens: &[Token], pos: &mut usize) -> Result { Ok(left) } -fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { - let left = parse_additive(tokens, pos)?; +fn parse_comparison( + tokens: &[Token], + pos: &mut usize, + depth: &mut usize, +) -> Result { + let left = parse_additive(tokens, pos, depth)?; if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "=" => BinaryOp::Eq, @@ -241,7 +249,7 @@ fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result return Ok(left), }; *pos += 1; - let right = parse_additive(tokens, pos)?; + let right = parse_additive(tokens, pos, depth)?; return Ok(SqlExpr::BinaryOp { left: Box::new(left), op, @@ -251,8 +259,8 @@ fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result Result { - let mut left = parse_multiplicative(tokens, pos)?; +fn parse_additive(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_multiplicative(tokens, pos, depth)?; while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "+" => BinaryOp::Add, @@ -261,7 +269,7 
@@ fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result _ => break, }; *pos += 1; - let right = parse_multiplicative(tokens, pos)?; + let right = parse_multiplicative(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op, @@ -271,8 +279,12 @@ fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result Ok(left) } -fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_unary(tokens, pos)?; +fn parse_multiplicative( + tokens: &[Token], + pos: &mut usize, + depth: &mut usize, +) -> Result { + let mut left = parse_unary(tokens, pos, depth)?; while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "*" => BinaryOp::Mul, @@ -281,7 +293,7 @@ fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result break, }; *pos += 1; - let right = parse_unary(tokens, pos)?; + let right = parse_unary(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op, @@ -291,23 +303,23 @@ fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result Result { +fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { // Unary minus. if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" { *pos += 1; - let expr = parse_primary(tokens, pos)?; + let expr = parse_primary(tokens, pos, depth)?; return Ok(SqlExpr::Negate(Box::new(expr))); } // NOT if peek_keyword(tokens, *pos, "NOT") { *pos += 1; - let expr = parse_primary(tokens, pos)?; + let expr = parse_primary(tokens, pos, depth)?; return Ok(SqlExpr::Negate(Box::new(expr))); } - parse_primary(tokens, pos) + parse_primary(tokens, pos, depth) } -fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { +fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { if *pos >= tokens.len() { return Err("unexpected end of expression".into()); } @@ -317,8 +329,15 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { match token.kind { // Parenthesized expression. TokenKind::LParen => { + *depth += 1; + if *depth > MAX_EXPR_DEPTH { + return Err(format!( + "expression nesting depth exceeds maximum of {MAX_EXPR_DEPTH}" + )); + } *pos += 1; - let expr = parse_expr(tokens, pos)?; + let expr = parse_expr(tokens, pos, depth)?; + *depth -= 1; expect_token(tokens, pos, TokenKind::RParen, ")")?; Ok(expr) } @@ -351,15 +370,15 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { "NULL" => Ok(SqlExpr::Literal(Value::Null)), "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))), "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))), - "CASE" => parse_case(tokens, pos), + "CASE" => parse_case(tokens, pos, depth), "COALESCE" => { - let args = parse_arg_list(tokens, pos)?; + let args = parse_arg_list(tokens, pos, depth)?; Ok(SqlExpr::Coalesce(args)) } _ => { // Function call: IDENT(args). if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen { - let args = parse_arg_list(tokens, pos)?; + let args = parse_arg_list(tokens, pos, depth)?; Ok(SqlExpr::Function { name: name.to_lowercase(), args, @@ -377,20 +396,20 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { } /// Parse `CASE WHEN cond THEN result [WHEN ... THEN ...] [ELSE result] END`. 
-fn parse_case(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
+fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
     let mut when_thens = Vec::new();
     let mut else_expr = None;
 
     loop {
         if peek_keyword(tokens, *pos, "WHEN") {
             *pos += 1;
-            let cond = parse_expr(tokens, pos)?;
+            let cond = parse_expr(tokens, pos, depth)?;
             expect_keyword(tokens, pos, "THEN")?;
-            let then = parse_expr(tokens, pos)?;
+            let then = parse_expr(tokens, pos, depth)?;
             when_thens.push((cond, then));
         } else if peek_keyword(tokens, *pos, "ELSE") {
             *pos += 1;
-            else_expr = Some(Box::new(parse_expr(tokens, pos)?));
+            else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?));
         } else if peek_keyword(tokens, *pos, "END") {
             *pos += 1;
             break;
@@ -411,7 +430,11 @@ fn parse_case(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
 }
 
 /// Parse a parenthesized, comma-separated argument list: `(expr, expr, ...)`.
-fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result<Vec<SqlExpr>, String> {
+fn parse_arg_list(
+    tokens: &[Token],
+    pos: &mut usize,
+    depth: &mut usize,
+) -> Result<Vec<SqlExpr>, String> {
     expect_token(tokens, pos, TokenKind::LParen, "(")?;
     let mut args = Vec::new();
     if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
@@ -419,7 +442,7 @@ fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result<Vec<SqlExpr>, Str
         return Ok(args);
     }
     loop {
-        args.push(parse_expr(tokens, pos)?);
+        args.push(parse_expr(tokens, pos, depth)?);
         if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
             *pos += 1;
         } else {
@@ -667,4 +690,18 @@ mod tests {
         let doc = Value::from(serde_json::json!({"price": 49.99}));
         assert_eq!(expr.eval(&doc), Value::Float(49.99));
     }
+
+    #[test]
+    fn deeply_nested_parentheses_return_error_not_stack_overflow() {
+        // Spec: the parser must enforce a recursion depth limit so that
+        // pathologically deep nesting returns Err rather than overflowing the
+        // call stack and causing a process crash.
+        let depth = 10_000;
+        let input = format!("{}x{}", "(".repeat(depth), ")".repeat(depth),);
+        let result = parse_generated_expr(&input);
+        assert!(
+            result.is_err(),
+            "parse_generated_expr must return Err for {depth}-deep nesting, not stack overflow"
+        );
+    }
 }
diff --git a/nodedb-query/src/msgpack_scan/aggregate_helpers.rs b/nodedb-query/src/msgpack_scan/aggregate_helpers.rs
new file mode 100644
index 00000000..10e33e61
--- /dev/null
+++ b/nodedb-query/src/msgpack_scan/aggregate_helpers.rs
@@ -0,0 +1,78 @@
+//! Public helpers for streaming aggregate accumulators.
+//!
+//! These thin wrappers expose field-extraction primitives used by the
+//! `handlers/aggregate.rs` streaming accumulator path in the `nodedb` crate.
+//! Each function operates on a single raw MessagePack document byte slice and
+//! returns only the scalar value needed by the calling accumulator — no
+//! document bytes are retained after the call returns.
+
+use nodedb_types::Value;
+
+use crate::expr::SqlExpr;
+use crate::msgpack_scan::field::extract_field;
+use crate::msgpack_scan::reader::{read_f64, read_str, read_value};
+use crate::value_ops;
+
+// ── Expression evaluator ───────────────────────────────────────────────────
+
+#[inline]
+fn eval_expr(doc: &[u8], expr: &SqlExpr) -> Option<Value> {
+    let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?;
+    Some(expr.eval(&doc_val))
+}
+
+// ── Public extraction helpers ──────────────────────────────────────────────
+
+/// Extract a numeric (f64) value from `field`, or evaluate `expr` if provided.
+/// Returns `None` when the field is absent or cannot be converted to f64.
+#[inline]
+pub fn extract_f64(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<f64> {
+    if let Some(expr) = expr {
+        return value_ops::value_to_f64(&eval_expr(doc, expr)?, false);
+    }
+    let (start, _end) = extract_field(doc, 0, field)?;
+    read_f64(doc, start)
+}
+
+/// Extract a display string from `field`, or evaluate `expr` if provided.
+/// Returns `None` when the field is absent.
+pub fn extract_str(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<String> {
+    if let Some(expr) = expr {
+        return Some(value_ops::value_to_display_string(&eval_expr(doc, expr)?));
+    }
+    let (start, _end) = extract_field(doc, 0, field)?;
+    read_str(doc, start).map(|s| s.to_string())
+}
+
+/// Extract a field as `Value`. Uses direct msgpack→Value for scalars;
+/// falls back to full document decode only for complex types.
+pub fn extract_value(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<Value> {
+    if let Some(expr) = expr {
+        return eval_expr(doc, expr);
+    }
+    let (start, end) = extract_field(doc, 0, field)?;
+    if let Some(v) = read_value(doc, start) {
+        return Some(v);
+    }
+    let field_bytes = &doc[start..end];
+    nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok()
+}
+
+/// Extract a field or expression result as raw msgpack bytes.
+/// Used by `count_distinct`, `approx_count_distinct`, `approx_topk`, etc.
+pub fn extract_bytes(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<Vec<u8>> {
+    if let Some(expr) = expr {
+        let val = eval_expr(doc, expr)?;
+        return nodedb_types::json_msgpack::value_to_msgpack(&val).ok();
+    }
+    let (start, end) = extract_field(doc, 0, field)?;
+    Some(doc[start..end].to_vec())
+}
+
+/// Returns `Some(())` when the field is present and non-null.
+/// Used by `count(field)` accumulator to count non-null values.
+#[inline]
+pub fn extract_non_null(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<()> {
+    let v = extract_value(doc, field, expr)?;
+    if v.is_null() { None } else { Some(()) }
+}
diff --git a/nodedb-query/src/msgpack_scan/mod.rs b/nodedb-query/src/msgpack_scan/mod.rs
index 46b0f29e..6e5981c4 100644
--- a/nodedb-query/src/msgpack_scan/mod.rs
+++ b/nodedb-query/src/msgpack_scan/mod.rs
@@ -5,6 +5,7 @@
 //! reads, comparisons, and hashing all work on raw byte offsets.
 
 pub mod aggregate;
+pub mod aggregate_helpers;
 pub mod compare;
 pub mod field;
 pub mod filter;
diff --git a/nodedb-sql/src/resolver/expr.rs b/nodedb-sql/src/resolver/expr.rs
index 99ed7772..be2d3bb6 100644
--- a/nodedb-sql/src/resolver/expr.rs
+++ b/nodedb-sql/src/resolver/expr.rs
@@ -6,8 +6,30 @@ use crate::error::{Result, SqlError};
 use crate::parser::normalize::normalize_ident;
 use crate::types::*;
 
+/// Maximum AST nesting depth accepted by `convert_expr`.
+/// Exceeding this limit returns `Err` instead of overflowing the stack.
+const MAX_CONVERT_DEPTH: usize = 128;
+
 /// Convert a sqlparser `Expr` to our `SqlExpr`.
 pub fn convert_expr(expr: &Expr) -> Result<SqlExpr> {
+    convert_expr_depth(expr, &mut 0)
+}
+
+/// Internal recursive helper that carries a depth counter to enforce
+/// `MAX_CONVERT_DEPTH` and prevent stack overflow on malformed ASTs.
+fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result { + *depth += 1; + if *depth > MAX_CONVERT_DEPTH { + return Err(SqlError::Unsupported { + detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"), + }); + } + let result = convert_expr_inner(expr, depth); + *depth -= 1; + result +} + +fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result { match expr { Expr::Identifier(ident) => Ok(SqlExpr::Column { table: None, @@ -19,22 +41,22 @@ pub fn convert_expr(expr: &Expr) -> Result { }), Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)), Expr::BinaryOp { left, op, right } => Ok(SqlExpr::BinaryOp { - left: Box::new(convert_expr(left)?), + left: Box::new(convert_expr_depth(left, depth)?), op: convert_binary_op(op)?, - right: Box::new(convert_expr(right)?), + right: Box::new(convert_expr_depth(right, depth)?), }), Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp { op: convert_unary_op(op)?, - expr: Box::new(convert_expr(expr)?), + expr: Box::new(convert_expr_depth(expr, depth)?), }), - Expr::Function(func) => convert_function(func), - Expr::Nested(inner) => convert_expr(inner), + Expr::Function(func) => convert_function_depth(func, depth), + Expr::Nested(inner) => convert_expr_depth(inner, depth), Expr::IsNull(inner) => Ok(SqlExpr::IsNull { - expr: Box::new(convert_expr(inner)?), + expr: Box::new(convert_expr_depth(inner, depth)?), negated: false, }), Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull { - expr: Box::new(convert_expr(inner)?), + expr: Box::new(convert_expr_depth(inner, depth)?), negated: true, }), Expr::InList { @@ -42,8 +64,11 @@ pub fn convert_expr(expr: &Expr) -> Result { list, negated, } => Ok(SqlExpr::InList { - expr: Box::new(convert_expr(expr)?), - list: list.iter().map(convert_expr).collect::>()?, + expr: Box::new(convert_expr_depth(expr, depth)?), + list: list + .iter() + .map(|e| convert_expr_depth(e, depth)) + .collect::>()?, negated: *negated, }), Expr::Between { @@ -52,9 +77,9 @@ pub fn convert_expr(expr: &Expr) -> Result { high, negated, } => Ok(SqlExpr::Between { - expr: Box::new(convert_expr(expr)?), - low: Box::new(convert_expr(low)?), - high: Box::new(convert_expr(high)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + low: Box::new(convert_expr_depth(low, depth)?), + high: Box::new(convert_expr_depth(high, depth)?), negated: *negated, }), Expr::Like { @@ -63,8 +88,8 @@ pub fn convert_expr(expr: &Expr) -> Result { negated, .. } => Ok(SqlExpr::Like { - expr: Box::new(convert_expr(expr)?), - pattern: Box::new(convert_expr(pattern)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + pattern: Box::new(convert_expr_depth(pattern, depth)?), negated: *negated, }), Expr::ILike { @@ -73,8 +98,8 @@ pub fn convert_expr(expr: &Expr) -> Result { negated, .. 
} => Ok(SqlExpr::Like { - expr: Box::new(convert_expr(expr)?), - pattern: Box::new(convert_expr(pattern)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + pattern: Box::new(convert_expr_depth(pattern, depth)?), negated: *negated, }), Expr::Case { @@ -85,46 +110,54 @@ pub fn convert_expr(expr: &Expr) -> Result { } => { let when_then = conditions .iter() - .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?))) + .map(|cw| { + Ok(( + convert_expr_depth(&cw.condition, depth)?, + convert_expr_depth(&cw.result, depth)?, + )) + }) .collect::>>()?; Ok(SqlExpr::Case { operand: operand .as_ref() - .map(|e| convert_expr(e).map(Box::new)) + .map(|e| convert_expr_depth(e, depth).map(Box::new)) .transpose()?, when_then, else_expr: else_result .as_ref() - .map(|e| convert_expr(e).map(Box::new)) + .map(|e| convert_expr_depth(e, depth).map(Box::new)) .transpose()?, }) } Expr::Cast { expr, data_type, .. } => Ok(SqlExpr::Cast { - expr: Box::new(convert_expr(expr)?), + expr: Box::new(convert_expr_depth(expr, depth)?), to_type: format!("{data_type}"), }), Expr::Array(ast::Array { elem, .. }) => { - let elems = elem.iter().map(convert_expr).collect::>()?; + let elems = elem + .iter() + .map(|e| convert_expr_depth(e, depth)) + .collect::>()?; Ok(SqlExpr::ArrayLiteral(elems)) } Expr::Wildcard(_) => Ok(SqlExpr::Wildcard), // TRIM([BOTH|LEADING|TRAILING] [what FROM] expr) Expr::Trim { expr, .. } => Ok(SqlExpr::Function { name: "trim".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), // CEIL(expr) / FLOOR(expr) Expr::Ceil { expr, .. } => Ok(SqlExpr::Function { name: "ceil".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), Expr::Floor { expr, .. } => Ok(SqlExpr::Function { name: "floor".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), // SUBSTRING(expr FROM start FOR len) @@ -134,12 +167,12 @@ pub fn convert_expr(expr: &Expr) -> Result { substring_for, .. } => { - let mut args = vec![convert_expr(expr)?]; + let mut args = vec![convert_expr_depth(expr, depth)?]; if let Some(from) = substring_from { - args.push(convert_expr(from)?); + args.push(convert_expr_depth(from, depth)?); } if let Some(len) = substring_for { - args.push(convert_expr(len)?); + args.push(convert_expr_depth(len, depth)?); } Ok(SqlExpr::Function { name: "substring".into(), @@ -241,7 +274,7 @@ pub fn convert_value(val: &Value) -> Result { } } -fn convert_function(func: &ast::Function) -> Result { +fn convert_function_depth(func: &ast::Function, depth: &mut usize) -> Result { let name = func .name .0 @@ -264,14 +297,16 @@ fn convert_function(func: &ast::Function) -> Result { .args .iter() .filter_map(|a| match a { - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(convert_expr(e)), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => { + Some(convert_expr_depth(e, depth)) + } ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => { Some(Ok(SqlExpr::Wildcard)) } ast::FunctionArg::Named { arg: ast::FunctionArgExpr::Expr(e), .. - } => Some(convert_expr(e)), + } => Some(convert_expr_depth(e, depth)), _ => None, }) .collect::>>()?, diff --git a/nodedb-vector/src/collection/search.rs b/nodedb-vector/src/collection/search.rs index d3d2712e..060f5b31 100644 --- a/nodedb-vector/src/collection/search.rs +++ b/nodedb-vector/src/collection/search.rs @@ -1,6 +1,6 @@ //! 
VectorCollection search: multi-segment merging with SQ8 reranking. -use crate::distance::{DistanceMetric, distance}; +use crate::distance::distance; use crate::hnsw::SearchResult; use super::lifecycle::VectorCollection; @@ -19,33 +19,15 @@ impl VectorCollection { // Search sealed segments. for seg in &self.sealed { - let results = if let Some((codec, sq8_data)) = &seg.sq8 { - // Quantized two-phase search. + let results = if let Some(_sq8) = &seg.sq8 { + // Quantized two-phase search: use HNSW graph for O(log N) candidate + // generation, then rerank with exact FP32 distance. let rerank_k = top_k.saturating_mul(3).max(20); - let mut candidates: Vec<(u32, f32)> = Vec::with_capacity(seg.index.len()); - let dim = seg.index.dim(); - for i in 0..seg.index.len() { - if seg.index.is_deleted(i as u32) { - continue; - } - let sq8_vec = &sq8_data[i * dim..(i + 1) * dim]; - let d = match self.params.metric { - DistanceMetric::L2 => codec.asymmetric_l2(query, sq8_vec), - DistanceMetric::Cosine => codec.asymmetric_cosine(query, sq8_vec), - DistanceMetric::InnerProduct => codec.asymmetric_ip(query, sq8_vec), - _ => { - let dequant = codec.dequantize(sq8_vec); - distance(query, &dequant, self.params.metric) - } - }; - candidates.push((i as u32, d)); - } - if candidates.len() > rerank_k { - candidates.select_nth_unstable_by(rerank_k, |a, b| { - a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); - candidates.truncate(rerank_k); - } + let hnsw_candidates = seg.index.search(query, rerank_k, ef); + let candidates: Vec<(u32, f32)> = hnsw_candidates + .into_iter() + .map(|r| (r.id, r.distance)) + .collect(); // Prefetch FP32 vectors for reranking candidates. if let Some(mmap) = &seg.mmap_vectors { @@ -252,4 +234,99 @@ mod tests { let results = coll.search(&[5.0, 0.0], 10, 64); assert!(results.iter().all(|r| r.id != 5)); } + + /// Build a sealed HNSW segment from `n` vectors of `dim=2`, where vector `i` + /// is `[i as f32, 0.0]`. Returns the collection with one sealed segment. + fn make_sealed_collection(n: usize) -> VectorCollection { + let mut coll = VectorCollection::new( + 2, + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + }, + ); + for i in 0..n { + coll.insert(vec![i as f32, 0.0]); + } + let req = coll.seal("seg").unwrap(); + let mut idx = HnswIndex::new(req.dim, req.params); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll + } + + /// Attach SQ8 quantization to the first sealed segment of `coll`. 
+ fn attach_sq8(coll: &mut VectorCollection) { + use crate::quantize::sq8::Sq8Codec; + + let sealed = &mut coll.sealed[0]; + let dim = sealed.index.dim(); + let n = sealed.index.len(); + let vecs: Vec> = (0..n) + .filter_map(|i| sealed.index.get_vector(i as u32).map(|v| v.to_vec())) + .collect(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = Sq8Codec::calibrate(&refs, dim); + let sq8_data: Vec = vecs.iter().flat_map(|v| codec.quantize(v)).collect(); + sealed.sq8 = Some((codec, sq8_data)); + } + + #[test] + fn sq8_search_returns_correct_nearest_neighbor() { + let mut coll = make_sealed_collection(200); + attach_sq8(&mut coll); + + let results = coll.search(&[100.0, 0.0], 5, 64); + assert!(!results.is_empty(), "expected non-empty results"); + assert_eq!( + results[0].id, 100, + "nearest neighbor of [100,0] should be id=100, got id={}", + results[0].id + ); + } + + #[test] + fn sq8_search_recall_matches_hnsw() { + // Build two identical collections — one without SQ8, one with. + let coll_plain = make_sealed_collection(500); + let mut coll_sq8 = make_sealed_collection(500); + attach_sq8(&mut coll_sq8); + + let query = [250.0f32, 0.0]; + let top_k = 5; + + let plain_results = coll_plain.search(&query, top_k, 64); + let sq8_results = coll_sq8.search(&query, top_k, 64); + + let plain_ids: std::collections::HashSet = + plain_results.iter().map(|r| r.id).collect(); + let sq8_ids: std::collections::HashSet = sq8_results.iter().map(|r| r.id).collect(); + + let overlap = plain_ids.intersection(&sq8_ids).count(); + assert!( + overlap >= 4, + "SQ8 recall too low: {overlap}/5 results matched plain HNSW (need >=4)" + ); + } + + #[test] + fn sq8_search_does_not_scan_all_vectors() { + // This test validates correctness of the SQ8 search path for a large + // segment. The bug being guarded against is an O(N) linear scan instead + // of graph-guided traversal: the fix must use HNSW with SQ8 as the + // distance function. Correctness (correct nearest neighbor) is the + // invariant that must be preserved when the implementation changes. + let mut coll = make_sealed_collection(2000); + attach_sq8(&mut coll); + + let results = coll.search(&[1000.0, 0.0], 5, 64); + assert!(!results.is_empty(), "expected non-empty results"); + assert_eq!( + results[0].id, 1000, + "nearest neighbor of [1000,0] should be id=1000, got id={}", + results[0].id + ); + } } diff --git a/nodedb-vector/src/hnsw/search.rs b/nodedb-vector/src/hnsw/search.rs index d59a5e22..c917943f 100644 --- a/nodedb-vector/src/hnsw/search.rs +++ b/nodedb-vector/src/hnsw/search.rs @@ -21,7 +21,9 @@ impl HnswIndex { return Vec::new(); } - let ef = ef.max(k); + /// Maximum beam width to prevent runaway search cost. + const MAX_EF: usize = 8192; + let ef = ef.max(k).min(MAX_EF); let Some(ep) = self.entry_point else { return Vec::new(); }; diff --git a/nodedb-vector/src/mmap_segment.rs b/nodedb-vector/src/mmap_segment.rs index ef2d235f..0c8947fb 100644 --- a/nodedb-vector/src/mmap_segment.rs +++ b/nodedb-vector/src/mmap_segment.rs @@ -92,7 +92,35 @@ impl MmapVectorSegment { u32::from_le(*ptr) as usize }; - let expected = HEADER_SIZE + count * dim * 4; + // Reject dim=0 with nonzero count: get_vector would compute offset=HEADER_SIZE + // for every ID, aliasing header bytes as vector data. 
+ if dim == 0 && count > 0 { + unsafe { + libc::munmap(base as *mut libc::c_void, file_size); + } + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "mmap segment has dim=0 with nonzero count", + )); + } + + // Use checked arithmetic to prevent usize overflow on crafted headers. + let data_bytes = dim + .checked_mul(count) + .and_then(|dc| dc.checked_mul(4)) + .and_then(|bytes| bytes.checked_add(HEADER_SIZE)); + let expected = match data_bytes { + Some(v) => v, + None => { + unsafe { + libc::munmap(base as *mut libc::c_void, file_size); + } + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("mmap segment header overflow: dim={dim}, count={count}"), + )); + } + }; if file_size < expected { unsafe { libc::munmap(base as *mut libc::c_void, file_size); @@ -121,7 +149,12 @@ impl MmapVectorSegment { if idx >= self.count { return None; } - let offset = self.data_offset + idx * self.dim * 4; + let byte_len = self.dim.checked_mul(4)?; + let offset = self.data_offset.checked_add(idx.checked_mul(byte_len)?)?; + let end = offset.checked_add(byte_len)?; + if end > self.mmap_size { + return None; + } unsafe { let ptr = self.base.add(offset) as *const f32; Some(std::slice::from_raw_parts(ptr, self.dim)) @@ -134,9 +167,24 @@ impl MmapVectorSegment { if idx >= self.count { return; } - let offset = self.data_offset + idx * self.dim * 4; + let byte_len = match self.dim.checked_mul(4) { + Some(v) => v, + None => return, + }; + let Some(idx_bytes) = idx.checked_mul(byte_len) else { + return; + }; + let Some(offset) = self.data_offset.checked_add(idx_bytes) else { + return; + }; + if offset + .checked_add(byte_len) + .is_none_or(|e| e > self.mmap_size) + { + return; + } let page_start = offset & !(4095); - let len = (self.dim * 4 + 4095) & !(4095); + let len = (byte_len + 4095) & !(4095); unsafe { libc::madvise( self.base.add(page_start) as *mut libc::c_void, @@ -249,4 +297,99 @@ mod tests { assert_eq!(seg.count(), 0); assert!(seg.get_vector(0).is_none()); } + + #[test] + fn overflow_dim_count_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("overflow.vseg"); + + // dim=0x40000001, count=0x40000001: count * dim * 4 overflows usize on 64-bit + // (0x40000001 * 0x40000001 * 4 = 0x4000000280000004, which wraps to a small value). + let dim: u32 = 0x40000001; + let count: u32 = 0x40000001; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + // No actual vector data — just a 8-byte header. + drop(f); + + let result = MmapVectorSegment::open(&path); + assert!( + result.is_err(), + "expected Err for overflow-inducing dim/count, got Ok" + ); + } + + #[test] + fn truncated_file_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("truncated.vseg"); + + // Header claims dim=3, count=100 but only 8 bytes of actual data. 
+ let dim: u32 = 3; + let count: u32 = 100; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + drop(f); + + let result = MmapVectorSegment::open(&path); + match result { + Err(e) => assert_eq!( + e.kind(), + std::io::ErrorKind::InvalidData, + "expected InvalidData, got {:?}", + e.kind() + ), + Ok(_) => panic!("expected Err for truncated file, got Ok"), + } + } + + #[test] + fn zero_dim_with_nonzero_count_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("zerodim.vseg"); + + // dim=0, count=1000: expected size = HEADER_SIZE + 0 = 8, so the size + // check passes, but get_vector would read header bytes as vector data. + // dim=0 must be rejected outright. + let dim: u32 = 0; + let count: u32 = 1000; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + // Write enough padding so the file passes a naive size check. + f.write_all(&[0u8; 64]).unwrap(); + drop(f); + + let result = MmapVectorSegment::open(&path); + assert!( + result.is_err(), + "expected Err for dim=0 with nonzero count, got Ok" + ); + } } diff --git a/nodedb-wal/Cargo.toml b/nodedb-wal/Cargo.toml index d21beb00..2171ca13 100644 --- a/nodedb-wal/Cargo.toml +++ b/nodedb-wal/Cargo.toml @@ -21,6 +21,7 @@ libc = { workspace = true } memmap2 = { workspace = true } io-uring = { workspace = true, optional = true } aes-gcm = { workspace = true } +getrandom = { workspace = true } [dev-dependencies] tokio = { workspace = true } diff --git a/nodedb-wal/src/crypto.rs b/nodedb-wal/src/crypto.rs index 75ec48b5..9bf525cd 100644 --- a/nodedb-wal/src/crypto.rs +++ b/nodedb-wal/src/crypto.rs @@ -4,7 +4,8 @@ //! - Header stays plaintext (needed for recovery scanning — magic, lsn, tenant_id) //! - Payload is encrypted before CRC computation //! - CRC covers the ciphertext (detects corruption of encrypted data) -//! - Nonce derived from LSN (deterministic — no extra storage, enables replay) +//! - Nonce = `[4-byte random epoch][8-byte LSN]` — epoch is generated per WAL +//! lifetime to prevent nonce reuse after snapshot restore or WAL truncation //! - Additional Authenticated Data (AAD) = header bytes (binds ciphertext to its header) //! //! On-disk format for encrypted payload: @@ -19,17 +20,27 @@ use aes_gcm::aead::{Aead, KeyInit}; use crate::error::{Result, WalError}; use crate::record::HEADER_SIZE; -/// AES-256-GCM key: exactly 32 bytes. +/// AES-256-GCM key with a random per-lifetime epoch for nonce disambiguation. +/// +/// The epoch is generated randomly at construction time. Each WAL lifetime +/// (process start, snapshot restore, segment creation) gets a fresh epoch, +/// ensuring that nonces are never reused even if LSNs restart from 1. #[derive(Clone)] pub struct WalEncryptionKey { cipher: Aes256Gcm, + /// Random 4-byte epoch: occupies the high 4 bytes of the 12-byte nonce. + /// Disambiguates nonces across WAL lifetimes with the same key. + epoch: [u8; 4], } impl WalEncryptionKey { - /// Create from a 32-byte key. + /// Create from a 32-byte key with a fresh random epoch. 
     pub fn from_bytes(key: &[u8; 32]) -> Self {
+        let mut epoch = [0u8; 4];
+        getrandom::fill(&mut epoch).expect("getrandom failed");
         Self {
             cipher: Aes256Gcm::new(key.into()),
+            epoch,
         }
     }
@@ -60,7 +71,7 @@ impl WalEncryptionKey {
         header_bytes: &[u8; HEADER_SIZE],
         plaintext: &[u8],
     ) -> Result<Vec<u8>> {
-        let nonce = lsn_to_nonce(lsn);
+        let nonce = lsn_to_nonce(&self.epoch, lsn);
         self.cipher
             .encrypt(
                 &nonce,
@@ -74,18 +85,25 @@ impl WalEncryptionKey {
             })
     }
 
+    /// The random epoch for this key instance.
+    pub fn epoch(&self) -> &[u8; 4] {
+        &self.epoch
+    }
+
     /// Decrypt a payload. Input is ciphertext + auth_tag (16 bytes at end).
     ///
+    /// - `epoch`: the epoch that was used during encryption (from the segment header)
     /// - `lsn`: must match the LSN used during encryption
     /// - `header_bytes`: must match the header used during encryption (AAD)
     /// - `ciphertext`: the encrypted payload (includes 16-byte auth tag)
     pub fn decrypt(
         &self,
+        epoch: &[u8; 4],
         lsn: u64,
         header_bytes: &[u8; HEADER_SIZE],
         ciphertext: &[u8],
     ) -> Result<Vec<u8>> {
-        let nonce = lsn_to_nonce(lsn);
+        let nonce = lsn_to_nonce(epoch, lsn);
         self.cipher
             .decrypt(
                 &nonce,
@@ -140,27 +158,23 @@ impl KeyRing {
 
     /// Decrypt: try current key first, then previous (if set).
     ///
+    /// `epoch` is the encryption epoch stored in the WAL segment header.
     /// This enables seamless key rotation — old data encrypted with the
    /// previous key can still be read while new data uses the current key.
     pub fn decrypt(
         &self,
+        epoch: &[u8; 4],
         lsn: u64,
         header_bytes: &[u8; HEADER_SIZE],
         ciphertext: &[u8],
     ) -> Result<Vec<u8>> {
-        match self.current.decrypt(lsn, header_bytes, ciphertext) {
-            Ok(plaintext) => Ok(plaintext),
-            Err(_) if self.previous.is_some() => {
-                // Current key failed — try previous key.
-                if let Some(prev) = self.previous.as_ref() {
-                    prev.decrypt(lsn, header_bytes, ciphertext)
-                } else {
-                    Err(crate::error::WalError::EncryptionError {
-                        detail: "key rotation state inconsistent".into(),
-                    })
-                }
-            }
-            Err(e) => Err(e),
+        match (
+            self.current.decrypt(epoch, lsn, header_bytes, ciphertext),
+            self.previous.as_ref(),
+        ) {
+            (Ok(plaintext), _) => Ok(plaintext),
+            (Err(_), Some(prev)) => prev.decrypt(epoch, lsn, header_bytes, ciphertext),
+            (Err(e), None) => Err(e),
         }
     }
 
@@ -183,14 +197,16 @@ impl KeyRing {
 /// AES-256-GCM auth tag size in bytes.
 pub const AUTH_TAG_SIZE: usize = 16;
 
-/// Derive a 12-byte nonce from an LSN.
+/// Derive a 12-byte nonce from an epoch and LSN.
 ///
-/// AES-256-GCM requires a 96-bit (12 byte) nonce. Since LSNs are monotonically
-/// increasing and globally unique, they make ideal deterministic nonces.
-/// We zero-pad the 8-byte LSN to 12 bytes.
-fn lsn_to_nonce(lsn: u64) -> aes_gcm::Nonce {
+/// AES-256-GCM requires a 96-bit (12 byte) nonce that must never repeat
+/// for the same key. Layout: `[4-byte random epoch][8-byte LSN]`.
+/// The epoch is generated randomly per WAL lifetime, so even if LSNs
+/// restart from 1 after a snapshot restore, the nonces remain unique.
+fn lsn_to_nonce(epoch: &[u8; 4], lsn: u64) -> aes_gcm::Nonce { let mut nonce_bytes = [0u8; 12]; - nonce_bytes[..8].copy_from_slice(&lsn.to_le_bytes()); + nonce_bytes[..4].copy_from_slice(epoch); + nonce_bytes[4..12].copy_from_slice(&lsn.to_le_bytes()); nonce_bytes.into() } @@ -211,6 +227,7 @@ mod tests { #[test] fn encrypt_decrypt_roundtrip() { let key = test_key(); + let epoch = *key.epoch(); let header = test_header(1); let plaintext = b"hello nodedb encryption"; @@ -218,43 +235,47 @@ mod tests { assert_ne!(&ciphertext[..plaintext.len()], plaintext); assert_eq!(ciphertext.len(), plaintext.len() + AUTH_TAG_SIZE); - let decrypted = key.decrypt(1, &header, &ciphertext).unwrap(); + let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap(); assert_eq!(decrypted, plaintext); } #[test] fn wrong_key_fails() { let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]); + let epoch1 = *key1.epoch(); let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]); let header = test_header(1); let ciphertext = key1.encrypt(1, &header, b"secret").unwrap(); - assert!(key2.decrypt(1, &header, &ciphertext).is_err()); + assert!(key2.decrypt(&epoch1, 1, &header, &ciphertext).is_err()); } #[test] fn wrong_lsn_fails() { let key = test_key(); + let epoch = *key.epoch(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"secret").unwrap(); // Different LSN = different nonce = decryption fails. - assert!(key.decrypt(2, &header, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 2, &header, &ciphertext).is_err()); } #[test] fn tampered_ciphertext_fails() { let key = test_key(); + let epoch = *key.epoch(); let header = test_header(1); let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap(); ciphertext[0] ^= 0xFF; - assert!(key.decrypt(1, &header, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 1, &header, &ciphertext).is_err()); } #[test] fn tampered_header_fails() { let key = test_key(); + let epoch = *key.epoch(); let header1 = test_header(1); let ciphertext = key.encrypt(1, &header1, b"secret").unwrap(); @@ -262,18 +283,19 @@ mod tests { // Tamper the AAD (header). let mut header2 = header1; header2[0] = 0xFF; - assert!(key.decrypt(1, &header2, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 1, &header2, &ciphertext).is_err()); } #[test] fn empty_payload() { let key = test_key(); + let epoch = *key.epoch(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"").unwrap(); assert_eq!(ciphertext.len(), AUTH_TAG_SIZE); // Just the tag. - let decrypted = key.decrypt(1, &header, &ciphertext).unwrap(); + let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap(); assert!(decrypted.is_empty()); } @@ -286,4 +308,27 @@ mod tests { let ct2 = key.encrypt(2, &test_header(2), plaintext).unwrap(); assert_ne!(ct1, ct2); } + + #[test] + fn same_lsn_different_wal_lifetimes_produce_different_ciphertext() { + // Simulate two WAL lifetimes: same key bytes, same LSN=1, but + // separate WalEncryptionKey instances (each gets a fresh random epoch). + // This models: write at LSN=1, wipe WAL, restart with same key, + // write at LSN=1 again. The two ciphertexts must differ. 
+ let key_bytes = [0x42u8; 32]; + let key1 = WalEncryptionKey::from_bytes(&key_bytes); + let key2 = WalEncryptionKey::from_bytes(&key_bytes); + let header = test_header(1); + let pt = b"same plaintext in two wal lifetimes"; + + let ct1 = key1.encrypt(1, &header, pt).unwrap(); + let ct2 = key2.encrypt(1, &header, pt).unwrap(); + + // SPEC: different WAL lifetimes (different epochs) must produce + // different ciphertext even with the same key bytes and LSN. + assert_ne!( + ct1, ct2, + "nonce reuse: same (key_bytes, lsn) must not produce identical ciphertext across WAL lifetimes" + ); + } } diff --git a/nodedb-wal/src/mmap_reader.rs b/nodedb-wal/src/mmap_reader.rs index 2fe4160c..5b6424d8 100644 --- a/nodedb-wal/src/mmap_reader.rs +++ b/nodedb-wal/src/mmap_reader.rs @@ -49,65 +49,67 @@ impl MmapWalReader { pub fn next_record(&mut self) -> Result> { let data = &self.mmap[..]; - // Check if we have enough bytes for a header. - if self.offset + HEADER_SIZE > data.len() { - return Ok(None); - } + loop { + // Check if we have enough bytes for a header. + if self.offset + HEADER_SIZE > data.len() { + return Ok(None); + } - // Parse header. - let header_bytes: &[u8; HEADER_SIZE] = data[self.offset..self.offset + HEADER_SIZE] - .try_into() - .map_err(|_| { - WalError::Io(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "header slice conversion failed", - )) - })?; - let header = RecordHeader::from_bytes(header_bytes); - - // Validate magic — corruption or end of valid data. - if header.magic != WAL_MAGIC { - return Ok(None); - } + // Parse header. + let header_bytes: &[u8; HEADER_SIZE] = data[self.offset..self.offset + HEADER_SIZE] + .try_into() + .map_err(|_| { + WalError::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "header slice conversion failed", + )) + })?; + let header = RecordHeader::from_bytes(header_bytes); + + // Validate magic — corruption or end of valid data. + if header.magic != WAL_MAGIC { + return Ok(None); + } - // Validate version. - if header.validate(self.offset as u64).is_err() { - return Ok(None); - } + // Validate version. + if header.validate(self.offset as u64).is_err() { + return Ok(None); + } - let payload_len = header.payload_len as usize; - let record_end = self.offset + HEADER_SIZE + payload_len; + let payload_len = header.payload_len as usize; + let record_end = self.offset + HEADER_SIZE + payload_len; - // Check if payload is fully within the mmap'd region. - if record_end > data.len() { - return Ok(None); // Torn write at segment end. - } + // Check if payload is fully within the mmap'd region. + if record_end > data.len() { + return Ok(None); // Torn write at segment end. + } - // Extract payload (copies from mmap to owned Vec). - let payload = data[self.offset + HEADER_SIZE..record_end].to_vec(); - self.offset = record_end; + // Extract payload (copies from mmap to owned Vec). + let payload = data[self.offset + HEADER_SIZE..record_end].to_vec(); + self.offset = record_end; - let record = WalRecord { header, payload }; + let record = WalRecord { header, payload }; - // Verify checksum. - if record.verify_checksum().is_err() { - return Ok(None); // Corruption — end of committed prefix. - } + // Verify checksum. + if record.verify_checksum().is_err() { + return Ok(None); // Corruption — end of committed prefix. + } - // Check record type. 
- let logical_type = record.logical_record_type(); - if RecordType::from_raw(logical_type).is_none() { - if RecordType::is_required(logical_type) { - return Err(WalError::UnknownRequiredRecordType { - record_type: header.record_type, - lsn: header.lsn, - }); + // Check record type. + let logical_type = record.logical_record_type(); + if RecordType::from_raw(logical_type).is_none() { + if RecordType::is_required(logical_type) { + return Err(WalError::UnknownRequiredRecordType { + record_type: header.record_type, + lsn: header.lsn, + }); + } + // Unknown optional record — skip and continue loop. + continue; } - // Unknown optional record — skip and continue. - return self.next_record(); - } - Ok(Some(record)) + return Ok(Some(record)); + } } /// Iterator over all valid records in the mmap'd segment. diff --git a/nodedb-wal/src/reader.rs b/nodedb-wal/src/reader.rs index 5936976b..b1914f41 100644 --- a/nodedb-wal/src/reader.rs +++ b/nodedb-wal/src/reader.rs @@ -53,72 +53,69 @@ impl WalReader { /// Returns `None` at EOF (clean end) or at the first corruption point. /// Returns `Err` only for I/O errors or unknown required record types. pub fn next_record(&mut self) -> Result> { - // Read header. - let mut header_buf = [0u8; HEADER_SIZE]; - match self.read_exact(&mut header_buf) { - Ok(()) => {} - Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - return Ok(None); // Clean EOF. + loop { + // Read header. + let mut header_buf = [0u8; HEADER_SIZE]; + match self.read_exact(&mut header_buf) { + Ok(()) => {} + Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); // Clean EOF. + } + Err(e) => return Err(e), } - Err(e) => return Err(e), - } - let header = RecordHeader::from_bytes(&header_buf); + let header = RecordHeader::from_bytes(&header_buf); - // Validate magic and version. - if header.validate(self.offset - HEADER_SIZE as u64).is_err() { - // Corruption or end of valid data — treat as end of committed prefix. - return Ok(None); - } + // Validate magic and version. + if header.validate(self.offset - HEADER_SIZE as u64).is_err() { + // Corruption or end of valid data — treat as end of committed prefix. + return Ok(None); + } - // Read payload. - let mut payload = vec![0u8; header.payload_len as usize]; - if !payload.is_empty() { - match self.read_exact(&mut payload) { - Ok(()) => {} - Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - // Torn write — record is incomplete. This is the end of committed prefix. - return Ok(None); + // Read payload. + let mut payload = vec![0u8; header.payload_len as usize]; + if !payload.is_empty() { + match self.read_exact(&mut payload) { + Ok(()) => {} + Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } - Err(e) => return Err(e), } - } - let record = WalRecord { header, payload }; - - // Verify checksum. - if record.verify_checksum().is_err() { - // Checksum mismatch — torn write or corruption. - // Try to recover from double-write buffer if available. - if let Some(dwb) = &mut self.double_write - && let Ok(Some(recovered)) = dwb.recover_record(header.lsn) - { - tracing::info!( - lsn = header.lsn, - "recovered torn write from double-write buffer" - ); - self.offset += recovered.payload.len() as u64; - return Ok(Some(recovered)); + let record = WalRecord { header, payload }; + + // Verify checksum. 
+ if record.verify_checksum().is_err() { + if let Some(dwb) = &mut self.double_write + && let Ok(Some(recovered)) = dwb.recover_record(header.lsn) + { + tracing::info!( + lsn = header.lsn, + "recovered torn write from double-write buffer" + ); + self.offset += recovered.payload.len() as u64; + return Ok(Some(recovered)); + } + return Ok(None); } - // No DWB recovery possible — end of committed prefix. - return Ok(None); - } - // Check if the record type is known (strip encrypted flag for lookup). - let logical_type = record.logical_record_type(); - if RecordType::from_raw(logical_type).is_none() { - if RecordType::is_required(logical_type) { - return Err(WalError::UnknownRequiredRecordType { - record_type: header.record_type, - lsn: header.lsn, - }); + // Check if the record type is known (strip encrypted flag for lookup). + let logical_type = record.logical_record_type(); + if RecordType::from_raw(logical_type).is_none() { + if RecordType::is_required(logical_type) { + return Err(WalError::UnknownRequiredRecordType { + record_type: header.record_type, + lsn: header.lsn, + }); + } + // Unknown optional record — skip and continue loop. + continue; } - // Unknown optional record — skip it and continue. - // (The record is already consumed, so just recurse.) - return self.next_record(); - } - Ok(Some(record)) + return Ok(Some(record)); + } } /// Iterator over all valid records in the WAL. @@ -246,4 +243,39 @@ mod tests { assert_eq!(records.len(), 1); assert_eq!(records[0].payload, b"good-record"); } + + #[test] + fn skip_many_unknown_optional_records_is_iterative() { + // Record type 99 has bit 15 clear (99 & 0x8000 == 0) and is not a + // known variant, so the reader must skip it as an unknown optional. + // With the current recursive implementation (line 118: `return + // self.next_record()`), 50 000 consecutive unknown optional records + // exhaust the stack and panic. After the fix converts the skip to a + // loop, all 50 000 are skipped without overflow and the one valid + // record at the end is returned. + const UNKNOWN_OPTIONAL: u16 = 99; // no 0x8000 bit → optional, not in enum + const SKIP_COUNT: usize = 50_000; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("many_unknown.wal"); + + { + let mut writer = WalWriter::open_without_direct_io(&path).unwrap(); + for _ in 0..SKIP_COUNT { + writer.append(UNKNOWN_OPTIONAL, 1, 0, b"skip-me").unwrap(); + } + writer + .append(RecordType::Put as u16, 1, 0, b"keep-me") + .unwrap(); + writer.sync().unwrap(); + } + + let reader = WalReader::open(&path).unwrap(); + let records: Vec<_> = reader.records().collect::>().unwrap(); + + // Only the single known Put record survives; all unknown optional + // records are silently discarded. + assert_eq!(records.len(), 1); + assert_eq!(records[0].payload, b"keep-me"); + } } diff --git a/nodedb-wal/src/record.rs b/nodedb-wal/src/record.rs index b89fa0bf..7431cb3e 100644 --- a/nodedb-wal/src/record.rs +++ b/nodedb-wal/src/record.rs @@ -263,9 +263,11 @@ impl WalRecord { /// Decrypt the payload if the record is encrypted. /// + /// `epoch` is the encryption epoch from the WAL segment header. /// Returns the plaintext payload. If not encrypted, returns the payload as-is. 
pub fn decrypt_payload( &self, + epoch: &[u8; 4], encryption_key: Option<&crate::crypto::WalEncryptionKey>, ) -> Result> { if !self.is_encrypted() { @@ -284,14 +286,19 @@ impl WalRecord { aad_header.crc32c = 0; let header_bytes = aad_header.to_bytes(); - key.decrypt(self.header.lsn, &header_bytes, &self.payload) + key.decrypt(epoch, self.header.lsn, &header_bytes, &self.payload) } /// Decrypt the payload using a key ring (supports dual-key rotation). /// + /// `epoch` is the encryption epoch from the WAL segment header. /// Tries the current key first, then falls back to the previous key. /// Returns the plaintext payload. If not encrypted, returns the payload as-is. - pub fn decrypt_payload_ring(&self, ring: Option<&crate::crypto::KeyRing>) -> Result> { + pub fn decrypt_payload_ring( + &self, + epoch: &[u8; 4], + ring: Option<&crate::crypto::KeyRing>, + ) -> Result> { if !self.is_encrypted() { return Ok(self.payload.clone()); } @@ -306,7 +313,7 @@ impl WalRecord { aad_header.crc32c = 0; let header_bytes = aad_header.to_bytes(); - ring.decrypt(self.header.lsn, &header_bytes, &self.payload) + ring.decrypt(epoch, self.header.lsn, &header_bytes, &self.payload) } /// Whether this record's payload is encrypted. diff --git a/nodedb-wal/src/writer.rs b/nodedb-wal/src/writer.rs index 57e03721..ec541487 100644 --- a/nodedb-wal/src/writer.rs +++ b/nodedb-wal/src/writer.rs @@ -316,25 +316,32 @@ impl WalWriter { self.buffer.as_slice() }; - // Use pwrite to write at the exact offset. + // Use pwrite to write at the exact offset, retrying on short writes. #[cfg(unix)] { use std::os::unix::io::AsRawFd; let fd = self.file.as_raw_fd(); - let written = unsafe { - libc::pwrite( - fd, - data.as_ptr() as *const libc::c_void, - data.len(), - self.file_offset as libc::off_t, - ) - }; - if written < 0 { - return Err(WalError::Io(std::io::Error::last_os_error())); + let mut remaining = data; + let mut write_offset = self.file_offset; + while !remaining.is_empty() { + let written = unsafe { + libc::pwrite( + fd, + remaining.as_ptr() as *const libc::c_void, + remaining.len(), + write_offset as libc::off_t, + ) + }; + if written < 0 { + return Err(WalError::Io(std::io::Error::last_os_error())); + } + let n = written as usize; + remaining = &remaining[n..]; + write_offset += n as u64; } } - self.file_offset += self.buffer.len() as u64; + self.file_offset += data.len() as u64; self.buffer.clear(); Ok(()) } diff --git a/nodedb/src/control/server/ilp_listener.rs b/nodedb/src/control/server/ilp_listener.rs index 26dddd53..0d9c8079 100644 --- a/nodedb/src/control/server/ilp_listener.rs +++ b/nodedb/src/control/server/ilp_listener.rs @@ -12,6 +12,10 @@ use std::sync::Arc; use sonic_rs; use tokio::io::{AsyncBufReadExt, BufReader}; + +/// Maximum byte length of a single ILP line. Lines exceeding this are +/// rejected and the connection is dropped to prevent memory exhaustion. 
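The writer change above loops until `libc::pwrite` has written every byte, instead of assuming a single call writes the whole buffer. A sketch of the same short-write contract using the safe standard-library positional-write API (`std::os::unix::fs::FileExt`); this is only an illustration of the retry semantics, not the crate's code, which calls raw `pwrite`:

```rust
use std::fs::File;
use std::io;
use std::os::unix::fs::FileExt;

/// Keep writing at the given offset until `data` is fully persisted.
fn pwrite_all(file: &File, mut offset: u64, mut data: &[u8]) -> io::Result<()> {
    while !data.is_empty() {
        // write_at may write fewer bytes than requested (a "short write").
        let n = file.write_at(data, offset)?;
        if n == 0 {
            return Err(io::Error::new(io::ErrorKind::WriteZero, "pwrite wrote 0 bytes"));
        }
        data = &data[n..];
        offset += n as u64;
    }
    Ok(())
}

fn main() -> io::Result<()> {
    let file = File::create("/tmp/pwrite_retry_demo.bin")?;
    pwrite_all(&file, 0, b"wal segment bytes")?;
    Ok(())
}
```

The standard library's `write_all_at` performs the same loop internally; the hand-rolled version in the diff keeps the raw `libc` call.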
+const MAX_ILP_LINE_BYTES: usize = 10 * 1024 * 1024; // 10 MiB use tokio::net::TcpListener; use tokio::sync::Semaphore; use tracing::{debug, info, warn}; @@ -96,16 +100,24 @@ impl IlpListener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + std::time::Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let cs = ConnStream::tls(tls_stream); if let Err(e) = handle_ilp_connection(cs, peer, &state).await { warn!(%peer, error = %e, "ILP TLS connection error (data may be lost)"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer, error = %e, "ILP TLS handshake failed"); } + Err(_) => { + warn!(%peer, "ILP TLS handshake timed out"); + } } drop(permit); }); @@ -158,8 +170,8 @@ async fn handle_ilp_connection( ) -> crate::Result<()> { debug!(%peer, "ILP connection accepted"); - let reader = BufReader::new(stream); - let mut lines = reader.lines(); + let mut reader = BufReader::new(stream); + let mut line_buf: Vec = Vec::with_capacity(4096); let mut batch = String::new(); let mut line_count = 0u64; let mut total_ingested = 0u64; @@ -181,17 +193,46 @@ async fn handle_ilp_connection( loop { tokio::select! { - // Read next line. - result = lines.next_line() => { + // Read next line with an enforced byte-length cap. + result = reader.read_until(b'\n', &mut line_buf) => { match result { - Ok(Some(line)) => { + Ok(0) => break, // Connection closed (EOF). + Ok(_) => { + // Enforce line length limit before any allocation. + if line_buf.len() > MAX_ILP_LINE_BYTES { + warn!( + %peer, + len = line_buf.len(), + limit = MAX_ILP_LINE_BYTES, + "ILP line exceeds maximum length — dropping connection" + ); + break; + } + + // Strip trailing newline / CRLF. + let line_bytes = line_buf + .strip_suffix(b"\r\n") + .or_else(|| line_buf.strip_suffix(b"\n")) + .unwrap_or(&line_buf); + + let line = match std::str::from_utf8(line_bytes) { + Ok(s) => s, + Err(_) => { + warn!(%peer, "ILP line is not valid UTF-8 — skipping"); + line_buf.clear(); + continue; + } + }; + if line.is_empty() || line.starts_with('#') { + line_buf.clear(); continue; } - batch.push_str(&line); + batch.push_str(line); batch.push('\n'); line_count += 1; + line_buf.clear(); // Flush when batch reaches adaptive target. if line_count >= batch_target { @@ -212,8 +253,7 @@ async fn handle_ilp_connection( ); } } - Ok(None) => break, // Connection closed. - Err(_) => break, // Read error. + Err(_) => break, // Read error. } } // Timer-based flush (for low-rate connections). 
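One caveat worth noting about the cap: `read_until` keeps appending to `line_buf` until it sees the newline, so the length check fires only after a full line has been buffered. If the goal is to also bound how much a single read can buffer while an over-long line is still arriving, wrapping the reader in tokio's `Take` is one option. A self-contained sketch (hypothetical helper, not part of the patch):

```rust
use std::io;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, BufReader};

/// Read one newline-terminated line, refusing to buffer more than `max` bytes
/// for it. Returns Ok(false) on EOF, and an error if the cap fills before a
/// newline is seen.
async fn read_capped_line<R: AsyncBufRead + Unpin>(
    reader: R,
    buf: &mut Vec<u8>,
    max: u64,
) -> io::Result<bool> {
    buf.clear();
    // `take` caps how many bytes this read_until call may pull in.
    let n = reader.take(max).read_until(b'\n', buf).await?;
    if n == 0 {
        return Ok(false); // EOF
    }
    // The cap was filled without a terminating newline: treat as over-long.
    if !buf.ends_with(b"\n") && n == max as usize {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "line exceeds maximum length"));
    }
    Ok(true)
}

#[tokio::main]
async fn main() -> io::Result<()> {
    let data: &[u8] = b"short line\nthis-line-is-definitely-longer-than-the-cap\n";
    let mut reader = BufReader::new(data);
    let mut buf = Vec::new();
    assert!(read_capped_line(&mut reader, &mut buf, 16).await?); // fits
    assert!(read_capped_line(&mut reader, &mut buf, 16).await.is_err()); // over the cap
    Ok(())
}
```

When the helper reports an over-length line, the caller should still drop the connection, as the patch does, since the rest of the oversized line remains unread in the socket.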
diff --git a/nodedb/src/control/server/listener.rs b/nodedb/src/control/server/listener.rs index a1424c96..19358372 100644 --- a/nodedb/src/control/server/listener.rs +++ b/nodedb/src/control/server/listener.rs @@ -120,16 +120,24 @@ impl Listener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let session = NativeSession::new_tls(tls_stream, peer_addr, state_clone, mode); if let Err(e) = session.run().await { warn!(%peer_addr, error = %e, "TLS session terminated with error"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer_addr, error = %e, "native TLS handshake failed"); } + Err(_) => { + warn!(%peer_addr, "native TLS handshake timed out"); + } } // Permit is held for the session's lifetime and // released on drop when this future completes. diff --git a/nodedb/src/control/server/pgwire/ddl/backup.rs b/nodedb/src/control/server/pgwire/ddl/backup.rs index d96d6979..3e496062 100644 --- a/nodedb/src/control/server/pgwire/ddl/backup.rs +++ b/nodedb/src/control/server/pgwire/ddl/backup.rs @@ -244,7 +244,7 @@ pub async fn restore_tenant( })?; let mut aad = [0u8; nodedb_wal::record::HEADER_SIZE]; aad[..6].copy_from_slice(b"BACKUP"); - key.decrypt(0, &aad, &raw_bytes[4..]) + key.decrypt(key.epoch(), 0, &aad, &raw_bytes[4..]) .map_err(|e| sqlstate_error("XX000", &format!("backup decryption failed: {e}")))? } else { raw_bytes diff --git a/nodedb/src/control/server/pgwire/handler/prepared/execute.rs b/nodedb/src/control/server/pgwire/handler/prepared/execute.rs index a33a6378..d4cd2e77 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/execute.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/execute.rs @@ -5,14 +5,17 @@ //! all DDL dispatch, transaction handling, and permission checks. use std::fmt::Debug; +use std::sync::Arc; use bytes::Bytes; +use futures::StreamExt; use futures::sink::Sink; use pgwire::api::portal::Portal; -use pgwire::api::results::Response; +use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response}; use pgwire::api::{ClientInfo, ClientPortalStore, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; +use sonic_rs; use super::super::core::NodeDbPgHandler; use super::statement::ParsedStatement; @@ -39,6 +42,15 @@ impl NodeDbPgHandler { let stmt = &portal.statement.statement; let tenant_id = identity.tenant_id; + // DSL passthroughs (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) cannot be + // handled by the planned-SQL path. Route them through the same full DSL + // dispatcher used by the simple-query handler. DSL statements do not use + // SQL parameter placeholders, so bound parameters are intentionally ignored. + if stmt.is_dsl { + let mut results = self.execute_sql(&identity, &addr, &stmt.sql).await?; + return Ok(results.pop().unwrap_or(Response::EmptyQuery)); + } + // Convert pgwire binary parameters to typed ParamValues for AST binding. 
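The same ten-second handshake guard is applied to the native, ILP, and RESP listeners. Reduced to a stand-in future (hypothetical helper, not from the patch), the shape of the change is: a client that opens the TCP connection but never finishes the TLS handshake could previously hold its connection permit for as long as it kept the socket open; with the timeout, the task ends and the permit is released.

```rust
use std::time::Duration;
use tokio::time::timeout;

/// Wrap any handshake-shaped future in a deadline. Ok(Err(_)) is a handshake
/// failure; the outer Err (Elapsed) means the deadline fired first.
async fn accept_with_deadline<F, T, E>(deadline: Duration, handshake: F) -> Result<T, String>
where
    F: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    match timeout(deadline, handshake).await {
        Ok(Ok(stream)) => Ok(stream),
        Ok(Err(e)) => Err(format!("handshake failed: {e}")),
        Err(_) => Err("handshake timed out".to_string()),
    }
}

#[tokio::main]
async fn main() {
    // A handshake that never completes: the deadline fires instead of the
    // connection hanging forever.
    let stalled = std::future::pending::<Result<(), std::io::Error>>();
    let out = accept_with_deadline(Duration::from_millis(50), stalled).await;
    assert_eq!(out.unwrap_err(), "handshake timed out");
}
```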
let params = convert_portal_params(&portal.parameters, &stmt.param_types)?; @@ -46,8 +58,118 @@ impl NodeDbPgHandler { let mut results = self .execute_planned_sql_with_params(&identity, &stmt.sql, tenant_id, &addr, ¶ms) .await?; - Ok(results.pop().unwrap_or(Response::EmptyQuery)) + let result = results.pop().unwrap_or(Response::EmptyQuery); + + // When the statement declared typed result columns via Describe, the + // client expects DataRow messages with one field per declared column. + // + // The generic `payload_to_response` path produces a single-column + // QueryResponse with the full JSON as one text field. In the extended- + // query protocol the RowDescription was already sent by Describe, so + // pgwire sends only the DataRow messages on Execute — the client maps + // them against the previously-described schema. A 1-field row against + // an N-column schema causes null values for columns 2..N. + // + // Fix: when result_fields is non-empty, consume the single-field stream, + // parse each JSON object, and re-encode with one pgwire field per + // declared column. + if !stmt.result_fields.is_empty() { + reproject_response(result, &stmt.result_fields).await + } else { + Ok(result) + } + } +} + +/// Re-encode a query response to match the column schema declared by Describe. +/// +/// Each DataRow from `payload_to_response` contains a single text field holding +/// a JSON object. We parse each object and extract fields in `result_fields` +/// order, producing a new QueryResponse whose rows have one field per declared +/// column. Missing fields are sent as SQL NULL. +/// +/// Non-query responses (execution tags) pass through unchanged. +async fn reproject_response( + response: Response, + result_fields: &[FieldInfo], +) -> PgWireResult { + let qr = match response { + Response::Query(qr) => qr, + other => return Ok(other), + }; + + let schema = Arc::new(result_fields.to_vec()); + let field_names: Vec = result_fields.iter().map(|f| f.name().to_string()).collect(); + + // Collect JSON objects from the single-column stream produced by + // payload_to_response. Each DataRow has exactly one field: a JSON string. + let json_rows = collect_json_rows(qr).await?; + + let mut pgwire_rows = Vec::with_capacity(json_rows.len()); + for obj in &json_rows { + let mut encoder = DataRowEncoder::new(schema.clone()); + for name in &field_names { + match obj.get(name) { + None | Some(serde_json::Value::Null) => { + let _ = encoder.encode_field(&Option::::None); + } + Some(v) => { + let text = match v { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + let _ = encoder.encode_field(&text); + } + } + } + pgwire_rows.push(Ok(encoder.take_row())); } + + Ok(Response::Query(QueryResponse::new( + schema, + futures::stream::iter(pgwire_rows), + ))) +} + +/// Consume a `QueryResponse` stream and decode the single text field of each +/// `DataRow` as a JSON object. +/// +/// `payload_to_response` always produces rows where field[0] is a JSON string. +/// The pgwire `DataRow.data` format is: for each field, 4-byte length (i32, +/// big-endian) followed by the field bytes. `-1` (0xFFFFFFFF) means SQL NULL. +async fn collect_json_rows(mut qr: QueryResponse) -> PgWireResult> { + let mut rows = Vec::new(); + while let Some(row_result) = qr.data_rows.next().await { + let row = row_result?; + // Decode field[0] from the raw DataRow wire format. 
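The comment block above explains why the single-JSON-field rows from `payload_to_response` have to be spread across the columns the client saw in RowDescription. Stripped of the pgwire types, the re-projection of one row reduces to the following (hedged illustration only; the real code encodes through `DataRowEncoder` and the declared `FieldInfo` schema):

```rust
use serde_json::{json, Value};

/// Spread one JSON row across the declared column order, with missing or
/// null keys mapped to SQL NULL (None).
fn reproject(row: &Value, columns: &[&str]) -> Vec<Option<String>> {
    columns
        .iter()
        .map(|name| match row.get(*name) {
            None | Some(Value::Null) => None, // SQL NULL
            Some(Value::String(s)) => Some(s.clone()),
            Some(other) => Some(other.to_string()),
        })
        .collect()
}

fn main() {
    let row = json!({"id": "a", "name": "alice"});
    // The client was described three columns; "age" is absent from this row.
    assert_eq!(
        reproject(&row, &["id", "name", "age"]),
        vec![Some("a".to_string()), Some("alice".to_string()), None]
    );
}
```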
+ let text = decode_first_field_text(&row.data); + if let Some(t) = text { + let val: serde_json::Value = + sonic_rs::from_str(t).unwrap_or_else(|_| serde_json::Value::String(t.to_string())); + rows.push(val); + } + } + Ok(rows) +} + +/// Decode the text bytes of the first field from a pgwire `DataRow` wire buffer. +/// +/// Wire format: for each field, 4-byte big-endian length followed by bytes. +/// Returns `None` for NULL fields or invalid encodings. +fn decode_first_field_text(data: &bytes::BytesMut) -> Option<&str> { + if data.len() < 4 { + return None; + } + let len = i32::from_be_bytes([data[0], data[1], data[2], data[3]]); + if len < 0 { + // NULL field. + return None; + } + let len = len as usize; + if data.len() < 4 + len { + return None; + } + std::str::from_utf8(&data[4..4 + len]).ok() } /// Convert pgwire portal parameters to typed `ParamValue` for AST-level binding. @@ -156,4 +278,27 @@ mod tests { assert!(matches!(result[0], nodedb_sql::ParamValue::Bool(v) if v == expected)); } } + + #[test] + fn decode_first_field_text_normal() { + // Wire format: 4-byte length (big-endian) + UTF-8 bytes. + let text = b"hello"; + let mut data = bytes::BytesMut::new(); + data.extend_from_slice(&(text.len() as i32).to_be_bytes()); + data.extend_from_slice(text); + assert_eq!(decode_first_field_text(&data), Some("hello")); + } + + #[test] + fn decode_first_field_text_null() { + // -1 length means SQL NULL. + let mut data = bytes::BytesMut::new(); + data.extend_from_slice(&(-1i32).to_be_bytes()); + assert_eq!(decode_first_field_text(&data), None); + } + + #[test] + fn decode_first_field_text_empty() { + assert_eq!(decode_first_field_text(&bytes::BytesMut::new()), None); + } } diff --git a/nodedb/src/control/server/pgwire/handler/prepared/parser.rs b/nodedb/src/control/server/pgwire/handler/prepared/parser.rs index d24f5a52..03d37585 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/parser.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/parser.rs @@ -112,10 +112,17 @@ impl QueryParser for NodeDbQueryParser { .unwrap_or(1); let (param_types, result_fields) = self.try_infer_types(sql, types, tenant_id); + // If type inference produced no result fields and the SQL matches a + // known DSL prefix, mark the statement as a DSL passthrough. The + // Execute handler will route it through the full DSL dispatcher + // (same as the simple-query path) instead of `execute_planned_sql_with_params`. + let is_dsl = result_fields.is_empty() && is_dsl_statement(sql); + Ok(ParsedStatement { sql: sql.to_owned(), param_types, result_fields, + is_dsl, }) } @@ -136,6 +143,25 @@ impl QueryParser for NodeDbQueryParser { } } +/// Return true if `sql` starts with a DSL keyword that `plan_sql` cannot parse. +/// +/// Mirrors the prefix checks in `ddl/router/dsl.rs` so the extended-query +/// Parse handler can mark such statements as DSL passthroughs and route them +/// through the DSL dispatcher at Execute time. +fn is_dsl_statement(sql: &str) -> bool { + let upper = sql.trim().to_uppercase(); + upper.starts_with("SEARCH ") + || upper.starts_with("GRAPH ") + || upper.starts_with("MATCH ") + || upper.starts_with("OPTIONAL MATCH ") + || upper.starts_with("CRDT MERGE ") + || upper.starts_with("UPSERT INTO ") + || upper.starts_with("CREATE VECTOR INDEX ") + || upper.starts_with("CREATE FULLTEXT INDEX ") + || upper.starts_with("CREATE SEARCH INDEX ") + || upper.starts_with("CREATE SPARSE INDEX ") +} + /// Count $1, $2, ... placeholders in SQL text. 
fn count_placeholders(sql: &str) -> usize { let mut max_idx = 0usize; diff --git a/nodedb/src/control/server/pgwire/handler/prepared/statement.rs b/nodedb/src/control/server/pgwire/handler/prepared/statement.rs index cbbde8fa..ed5b0b45 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/statement.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/statement.rs @@ -21,4 +21,8 @@ pub struct ParsedStatement { /// Result column schema inferred from the logical plan. /// Empty for DML statements (INSERT/UPDATE/DELETE). pub result_fields: Vec, + /// True when the SQL is a DSL statement (SEARCH, GRAPH, MATCH, UPSERT INTO, + /// etc.) that `plan_sql` cannot parse. The Execute handler routes these + /// through the full DSL dispatcher instead of `execute_planned_sql_with_params`. + pub is_dsl: bool, } diff --git a/nodedb/src/control/server/resp/listener.rs b/nodedb/src/control/server/resp/listener.rs index 7fc6b973..7e93e654 100644 --- a/nodedb/src/control/server/resp/listener.rs +++ b/nodedb/src/control/server/resp/listener.rs @@ -102,16 +102,24 @@ impl RespListener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + std::time::Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let cs = ConnStream::tls(tls_stream); if let Err(e) = handle_connection(cs, peer, &state).await { debug!(%peer, error = %e, "RESP TLS connection error"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer, error = %e, "RESP TLS handshake failed"); } + Err(_) => { + warn!(%peer, "RESP TLS handshake timed out"); + } } drop(permit); }); diff --git a/nodedb/src/data/executor/core_loop/mod.rs b/nodedb/src/data/executor/core_loop/mod.rs index 6232dd13..15b02d3c 100644 --- a/nodedb/src/data/executor/core_loop/mod.rs +++ b/nodedb/src/data/executor/core_loop/mod.rs @@ -155,6 +155,13 @@ pub struct CoreLoop { pub(in crate::data::executor) columnar_engines: HashMap, + /// Flushed columnar segment bytes, keyed by "{tid}:{collection}". + /// Each entry is a list of encoded segment buffers produced by `SegmentWriter`. + /// Kept in memory so `scan_columnar` can read rows that were drained from the + /// active memtable during a flush (otherwise those rows would be lost until a + /// real on-disk segment reader is wired up). + pub(in crate::data::executor) columnar_flushed_segments: HashMap>>, + /// Per-collection max WAL LSN that has been ingested into the memtable. /// Used by the WAL catch-up deduplication: if a catch-up record's LSN /// is <= this value, the Data Plane skips it (already ingested). @@ -283,6 +290,7 @@ impl CoreLoop { ), columnar_memtables: HashMap::new(), columnar_engines: HashMap::new(), + columnar_flushed_segments: HashMap::new(), ts_max_ingested_lsn: HashMap::new(), last_ts_ingest: None, ts_last_value_caches: HashMap::new(), diff --git a/nodedb/src/data/executor/handlers/accum.rs b/nodedb/src/data/executor/handlers/accum.rs new file mode 100644 index 00000000..e269d72f --- /dev/null +++ b/nodedb/src/data/executor/handlers/accum.rs @@ -0,0 +1,371 @@ +//! Streaming aggregate accumulators for the generic GROUP BY path. +//! +//! Each `AggAccum` variant holds only the derived state needed to compute the +//! final aggregate result — no raw document bytes are retained. Memory per +//! group is O(num_aggregates × accumulator_size) regardless of how many +//! documents match that group. 
+ +use std::collections::HashSet; + +use crate::bridge::physical_plan::AggregateSpec; +use nodedb_types::Value; + +/// Maximum items collected by materializing aggregates (`array_agg`, +/// `array_agg_distinct`, `percentile_cont`, `string_agg`). +pub(super) const ARRAY_AGG_CAP: usize = 10_000; + +/// Per-(group, aggregate-spec) running accumulator. +pub(super) enum AggAccum { + /// count(*) or count(field). + Count { n: u64 }, + /// sum / avg: Kahan-compensated running sum + count. + SumAvg { sum: f64, comp: f64, n: u64 }, + /// min. + Min { best: Option }, + /// max. + Max { best: Option }, + /// count_distinct: set of raw msgpack bytes. + CountDistinct { seen: HashSet> }, + /// stddev / variance variants: Welford M2 accumulator. + Welford { n: u64, mean: f64, m2: f64 }, + /// approx_count_distinct: HyperLogLog. + Hll { + hll: nodedb_types::approx::HyperLogLog, + }, + /// approx_percentile: t-digest. + TDigest { + digest: nodedb_types::approx::TDigest, + }, + /// approx_topk: space-saving. + TopK { + ss: nodedb_types::approx::SpaceSaving, + k: usize, + }, + /// array_agg (capped). + ArrayAgg { values: Vec }, + /// array_agg_distinct (capped). + ArrayAggDistinct { + seen: HashSet>, + values: Vec, + }, + /// percentile_cont (capped). + PercentileCont { values: Vec, pct: f64 }, + /// string_agg / group_concat (capped). + StringAgg { parts: Vec }, +} + +impl AggAccum { + pub(super) fn new(agg: &AggregateSpec) -> Self { + match agg.function.as_str() { + "count" => AggAccum::Count { n: 0 }, + "sum" | "avg" => AggAccum::SumAvg { + sum: 0.0, + comp: 0.0, + n: 0, + }, + "min" => AggAccum::Min { best: None }, + "max" => AggAccum::Max { best: None }, + "count_distinct" => AggAccum::CountDistinct { + seen: HashSet::new(), + }, + "stddev" | "stddev_pop" | "stddev_samp" | "variance" | "var_pop" | "var_samp" => { + AggAccum::Welford { + n: 0, + mean: 0.0, + m2: 0.0, + } + } + "approx_count_distinct" => AggAccum::Hll { + hll: nodedb_types::approx::HyperLogLog::new(), + }, + "approx_percentile" => AggAccum::TDigest { + digest: nodedb_types::approx::TDigest::new(), + }, + "approx_topk" => { + let k: usize = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(10); + AggAccum::TopK { + ss: nodedb_types::approx::SpaceSaving::new(k), + k, + } + } + "array_agg" => AggAccum::ArrayAgg { values: Vec::new() }, + "array_agg_distinct" => AggAccum::ArrayAggDistinct { + seen: HashSet::new(), + values: Vec::new(), + }, + "percentile_cont" => { + let pct = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(0.5); + AggAccum::PercentileCont { + values: Vec::new(), + pct, + } + } + "string_agg" | "group_concat" => AggAccum::StringAgg { parts: Vec::new() }, + _ => AggAccum::Count { n: 0 }, + } + } + + /// Feed one document into this accumulator. 
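The `SumAvg` and `Welford` variants above keep O(1) state per group: a compensated running sum for sum/avg, and the (n, mean, M2) triple for the stddev/variance family. A standalone check of both recurrences, using the same update steps as `feed` but with hypothetical driver code, verified against closed-form and two-pass reference values:

```rust
fn main() {
    let xs: Vec<f64> = (0..10_000).map(|i| 1.0 + (i as f64) * 1e-7).collect();

    // Kahan-compensated summation (the SumAvg update).
    let (mut sum, mut comp) = (0.0f64, 0.0f64);
    for &x in &xs {
        let y = x - comp;
        let t = sum + y;
        comp = (t - sum) - y;
        sum = t;
    }
    // Exact value: 10_000 * 1.0 + 1e-7 * (0 + 1 + ... + 9_999) = 10_004.9995.
    assert!((sum - 10_004.9995).abs() < 1e-9);

    // Welford's online mean / M2 update (the stddev and variance arms).
    let (mut n, mut mean, mut m2) = (0u64, 0.0f64, 0.0f64);
    for &x in &xs {
        n += 1;
        let delta = x - mean;
        mean += delta / n as f64;
        m2 += delta * (x - mean);
    }
    let sample_variance = m2 / (n - 1) as f64;

    // Two-pass reference.
    let ref_mean = xs.iter().sum::<f64>() / xs.len() as f64;
    let ref_var =
        xs.iter().map(|x| (x - ref_mean).powi(2)).sum::<f64>() / (xs.len() - 1) as f64;
    assert!(((sample_variance - ref_var) / ref_var).abs() < 1e-9);

    println!("sum = {sum:.4}, mean = {mean:.8}, sample variance = {sample_variance:.6e}");
}
```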
+ pub(super) fn feed(&mut self, agg: &AggregateSpec, doc: &[u8]) { + use nodedb_query::msgpack_scan::aggregate_helpers as ah; + match self { + AggAccum::Count { n } => { + if (agg.field == "*" && agg.expr.is_none()) + || ah::extract_non_null(doc, &agg.field, agg.expr.as_ref()).is_some() + { + *n += 1; + } + } + AggAccum::SumAvg { sum, comp, n } => { + if let Some(v) = ah::extract_f64(doc, &agg.field, agg.expr.as_ref()) { + let y = v - *comp; + let t = *sum + y; + *comp = (t - *sum) - y; + *sum = t; + *n += 1; + } + } + AggAccum::Min { best } => { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + if v.is_null() { + return; + } + let replace = match best { + None => true, + Some(cur) => { + nodedb_query::value_ops::compare_values(&v, cur) + == std::cmp::Ordering::Less + } + }; + if replace { + *best = Some(v); + } + } + } + AggAccum::Max { best } => { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + if v.is_null() { + return; + } + let replace = match best { + None => true, + Some(cur) => { + nodedb_query::value_ops::compare_values(&v, cur) + == std::cmp::Ordering::Greater + } + }; + if replace { + *best = Some(v); + } + } + } + AggAccum::CountDistinct { seen } => { + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + seen.insert(bytes); + } + } + AggAccum::Welford { n, mean, m2 } => { + if let Some(v) = ah::extract_f64(doc, &agg.field, agg.expr.as_ref()) { + *n += 1; + let delta = v - *mean; + *mean += delta / *n as f64; + let delta2 = v - *mean; + *m2 += delta * delta2; + } + } + AggAccum::Hll { hll } => { + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + hll.add(fnv1a(&bytes)); + } + } + AggAccum::TDigest { digest } => { + let actual = field_after_colon(&agg.field); + if let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) { + digest.add(v); + } + } + AggAccum::TopK { ss, .. } => { + let actual = field_after_colon(&agg.field); + if let Some(bytes) = ah::extract_bytes(doc, actual, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + ss.add(fnv1a(&bytes)); + } + } + AggAccum::ArrayAgg { values } => { + if values.len() < ARRAY_AGG_CAP + && let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) + && !v.is_null() + { + values.push(v); + } + } + AggAccum::ArrayAggDistinct { seen, values } => { + if values.len() < ARRAY_AGG_CAP + && let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + && seen.insert(bytes) + && let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) + { + values.push(v); + } + } + AggAccum::PercentileCont { values, .. } => { + let actual = field_after_colon(&agg.field); + if values.len() < ARRAY_AGG_CAP + && let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) + { + values.push(v); + } + } + AggAccum::StringAgg { parts } => { + if parts.len() < ARRAY_AGG_CAP + && let Some(s) = ah::extract_str(doc, &agg.field, agg.expr.as_ref()) + { + parts.push(s); + } + } + } + } + + /// Consume the accumulator and produce the final `Value`. + pub(super) fn finalize(self, agg: &AggregateSpec) -> Value { + match self { + AggAccum::Count { n } => Value::Integer(n as i64), + AggAccum::SumAvg { sum, n, .. 
} => { + if agg.function == "avg" { + if n == 0 { + Value::Null + } else { + Value::Float(sum / n as f64) + } + } else { + Value::Float(sum) + } + } + AggAccum::Min { best } => best.unwrap_or(Value::Null), + AggAccum::Max { best } => best.unwrap_or(Value::Null), + AggAccum::CountDistinct { seen } => Value::Integer(seen.len() as i64), + AggAccum::Welford { n, mean: _, m2 } => { + if n < 2 { + return Value::Null; + } + let population = matches!( + agg.function.as_str(), + "stddev" | "stddev_pop" | "variance" | "var_pop" + ); + let divisor = if population { n as f64 } else { (n - 1) as f64 }; + let variance = m2 / divisor; + let result = if agg.function.contains("stddev") { + variance.sqrt() + } else { + variance + }; + Value::Float(result) + } + AggAccum::Hll { hll } => Value::Integer(hll.estimate().round() as i64), + AggAccum::TDigest { digest } => { + let pct = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(0.5); + let r = digest.quantile(pct); + if r.is_nan() { + Value::Null + } else { + Value::Float(r) + } + } + AggAccum::TopK { ss, k } => { + let arr: Vec = ss + .top_k() + .into_iter() + .take(k) + .map(|(item, count, error)| { + Value::Object( + [ + ("item".to_string(), Value::Integer(item as i64)), + ("count".to_string(), Value::Integer(count as i64)), + ("error".to_string(), Value::Integer(error as i64)), + ] + .into_iter() + .collect(), + ) + }) + .collect(); + Value::Array(arr) + } + AggAccum::ArrayAgg { values } => Value::Array(values), + AggAccum::ArrayAggDistinct { values, .. } => Value::Array(values), + AggAccum::PercentileCont { mut values, pct } => { + if values.is_empty() { + return Value::Null; + } + values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64); + let lo = idx.floor() as usize; + let hi = idx.ceil() as usize; + let frac = idx - lo as f64; + Value::Float(values[lo] * (1.0 - frac) + values[hi] * frac) + } + AggAccum::StringAgg { parts } => Value::String(parts.join(",")), + } + } +} + +/// Per-group running state: one `AggAccum` per aggregate spec. +pub(super) struct GroupState { + pub(super) accums: Vec, +} + +impl GroupState { + pub(super) fn new(aggregates: &[AggregateSpec]) -> Self { + Self { + accums: aggregates.iter().map(AggAccum::new).collect(), + } + } + + pub(super) fn feed(&mut self, aggregates: &[AggregateSpec], doc: &[u8]) { + for (accum, agg) in self.accums.iter_mut().zip(aggregates) { + accum.feed(agg, doc); + } + } + + pub(super) fn finalize(self, aggregates: &[AggregateSpec]) -> Vec<(String, Value)> { + self.accums + .into_iter() + .zip(aggregates) + .map(|(accum, agg)| (agg.alias.clone(), accum.finalize(agg))) + .collect() + } +} + +/// FNV-1a hash (matches the implementation in nodedb-query aggregate.rs). +#[inline] +pub(super) fn fnv1a(bytes: &[u8]) -> u64 { + let mut h: u64 = 0xcbf29ce484222325; + for &b in bytes { + h ^= b as u64; + h = h.wrapping_mul(0x100000001b3); + } + h +} + +/// Extract the actual field name from "prefix:field" format (e.g. "0.95:latency"). +#[inline] +pub(super) fn field_after_colon(field: &str) -> &str { + field.find(':').map(|i| &field[i + 1..]).unwrap_or(field) +} diff --git a/nodedb/src/data/executor/handlers/aggregate.rs b/nodedb/src/data/executor/handlers/aggregate.rs index e815c995..e677ef50 100644 --- a/nodedb/src/data/executor/handlers/aggregate.rs +++ b/nodedb/src/data/executor/handlers/aggregate.rs @@ -1,8 +1,17 @@ //! 
Aggregate handler: GROUP BY, HAVING, and aggregate function execution. +//! +//! The generic (non-columnar) path uses **streaming accumulators** — see +//! `accum.rs`. Raw document bytes are never stored; only the extracted +//! scalar / approximate values needed by each aggregate function are kept. +//! Memory is O(num_groups × num_aggregates) instead of +//! O(total_matching_docs × avg_doc_size). + +use std::collections::HashMap; use sonic_rs; use tracing::debug; +use super::accum::GroupState; use crate::bridge::envelope::{ErrorCode, Response}; use crate::bridge::physical_plan::AggregateSpec; use crate::bridge::scan_filter::ScanFilter; @@ -11,10 +20,8 @@ use crate::data::executor::task::ExecutionTask; use nodedb_query::agg_key::canonical_agg_key; use nodedb_query::msgpack_scan; -/// Build a cache key for an aggregate query. -/// -/// Format: `"{tid}:{collection}\0{group_fields}\0{agg_ops}"`. -/// Null bytes separate sections to avoid ambiguity with field names. +// ── Cache key ────────────────────────────────────────────────────────────── + fn aggregate_cache_key( tid: u32, collection: &str, @@ -60,37 +67,6 @@ fn aggregate_cache_key( key } -/// Group a single document into the binary_groups map. -/// -/// Applies filter predicates, computes group key, and stores the raw -/// document bytes for later aggregation. -fn group_doc( - value: &[u8], - group_by: &[String], - filter_predicates: &[ScanFilter], - use_field_index: bool, - binary_groups: &mut std::collections::HashMap>>, -) { - if use_field_index { - let idx = msgpack_scan::FieldIndex::build(value, 0) - .unwrap_or_else(msgpack_scan::FieldIndex::empty); - if !filter_predicates - .iter() - .all(|f| f.matches_binary_indexed(value, &idx)) - { - return; - } - let key = msgpack_scan::group_key::build_group_key_indexed(value, group_by, &idx); - binary_groups.entry(key).or_default().push(value.to_vec()); - } else { - if !filter_predicates.iter().all(|f| f.matches_binary(value)) { - return; - } - let key = msgpack_scan::build_group_key(value, group_by); - binary_groups.entry(key).or_default().push(value.to_vec()); - } -} - fn legacy_aggregate_pairs(aggregates: &[AggregateSpec]) -> Option> { aggregates .iter() @@ -130,6 +106,8 @@ fn apply_user_aliases_to_rows(rows: &mut [serde_json::Value], aggregates: &[Aggr } } +// ── CoreLoop impl ────────────────────────────────────────────────────────── + impl CoreLoop { #[allow(clippy::too_many_arguments)] pub(in crate::data::executor) fn execute_aggregate( @@ -148,8 +126,6 @@ impl CoreLoop { debug!(core = self.core_id, %collection, group_fields = group_by.len(), aggs = aggregates.len(), "aggregate"); // Fast path: incremental aggregate cache. - // If we've cached the result for this exact (collection, group_by, aggregates) - // combination and there are no filters/having, return cached result directly. if filters.is_empty() && having.is_empty() { let cache_key = aggregate_cache_key( tid, @@ -166,9 +142,6 @@ impl CoreLoop { } // Fast path: index-backed COUNT/GROUP BY. - // When GROUP BY has a single field, no filters, no HAVING, and the - // only aggregate is COUNT(*), scan the INDEXES table directly. - // No document table access — O(index_entries) instead of O(documents). if group_by.len() == 1 && filters.is_empty() && having.is_empty() @@ -177,12 +150,9 @@ impl CoreLoop { && aggregates[0].function == "count" { let field = &group_by[0]; - // Empty index — fall through to full scan (documents may exist - // without index entries if no secondary indexes are declared). 
if let Ok(groups) = self.sparse.scan_index_groups(tid, collection, field) && !groups.is_empty() { - // Build result rows as raw msgpack — no serde_json::Value. let mut payload_buf = Vec::with_capacity(groups.len() * 64); let row_count = groups.len().min(limit); let count_key = aggregates[0] @@ -195,8 +165,7 @@ impl CoreLoop { msgpack_scan::write_kv_str(&mut payload_buf, field, &value); msgpack_scan::write_kv_i64(&mut payload_buf, &count_key, count as i64); } - let results_payload = payload_buf; - return match Ok::, crate::Error>(results_payload) { + return match Ok::, crate::Error>(payload_buf) { Ok(payload) => self.response_with_payload(task, payload), Err(e) => self.response_error( task, @@ -208,22 +177,14 @@ impl CoreLoop { } } - // Aggregates must scan all matching documents for correct results. - // Cap at aggregate_scan_cap to prevent OOM on unbounded collections. let scan_limit = self.query_tuning.aggregate_scan_cap; - // If collection has columnar memtable data, read from there. - // Works for all columnar profiles: plain, timeseries, spatial. - // Spatial inserts write to both sparse (R-tree) and columnar (scans/aggregates). let columnar_mt = self .columnar_memtables .get(collection) .filter(|mt| !mt.is_empty()); // Fast path: native columnar aggregation. - // Groups directly on symbol IDs (u32) instead of JSON-serialized strings. - // Accumulates in-place without document construction. - // Falls back to generic path for complex filters (OR, string comparisons). if let Some(mt) = columnar_mt.filter(|_| sub_group_by.is_empty() && sub_aggregates.is_empty()) { @@ -250,7 +211,6 @@ impl CoreLoop { scan_limit, ) }) { - // Apply HAVING filters. if !having.is_empty() { let having_predicates: Vec = match zerompk::from_msgpack(having) { Ok(h) => h, @@ -295,13 +255,13 @@ impl CoreLoop { ), }; } - // Native path returned None — fall through to generic path. } - // ── Streaming aggregation: process documents in chunks ── - // Instead of loading all documents into memory, scan in chunks of - // 10K docs, group + aggregate each chunk, then merge partial results. - // Memory: O(chunk_size + num_groups) instead of O(all_docs). + // ── Streaming aggregation ────────────────────────────────────────── + // Documents are processed one at a time. Per-group accumulators hold + // only the derived scalar / approximate state needed for the final + // result — no raw document bytes are retained. + // Memory: O(num_groups × num_aggregates) instead of O(all_docs). let filter_predicates: Vec = if filters.is_empty() { Vec::new() @@ -316,28 +276,51 @@ impl CoreLoop { }; let use_field_index = filter_predicates.len() + group_by.len() >= 2; + let need_sub = !sub_group_by.is_empty() && !sub_aggregates.is_empty(); - // Accumulate per-group doc bytes across all chunks. - // Key: group_key string, Value: collected raw doc bytes for final aggregation. - let mut binary_groups: std::collections::HashMap>> = - std::collections::HashMap::new(); + // outer_group_key → GroupState + let mut groups: HashMap = HashMap::new(); + // outer_group_key → sub_group_key → GroupState + let mut sub_groups: HashMap> = HashMap::new(); let chunk_size = 10_000; - // Universal scan: routes to the correct engine (KV, columnar, sparse/strict) - // and normalizes all results to standard msgpack maps. 
let scan_result = self .scan_collection(tid, collection, scan_limit) .map(|docs| { for chunk in docs.chunks(chunk_size) { for (_, value) in chunk { - group_doc( - value, - group_by, - &filter_predicates, - use_field_index, - &mut binary_groups, - ); + let outer_key = if use_field_index { + let idx = msgpack_scan::FieldIndex::build(value, 0) + .unwrap_or_else(msgpack_scan::FieldIndex::empty); + if !filter_predicates + .iter() + .all(|f| f.matches_binary_indexed(value, &idx)) + { + continue; + } + msgpack_scan::group_key::build_group_key_indexed(value, group_by, &idx) + } else { + if !filter_predicates.iter().all(|f| f.matches_binary(value)) { + continue; + } + msgpack_scan::build_group_key(value, group_by) + }; + + groups + .entry(outer_key.clone()) + .or_insert_with(|| GroupState::new(aggregates)) + .feed(aggregates, value); + + if need_sub { + let sub_key = msgpack_scan::build_group_key(value, sub_group_by); + sub_groups + .entry(outer_key) + .or_default() + .entry(sub_key) + .or_insert_with(|| GroupState::new(sub_aggregates)) + .feed(sub_aggregates, value); + } } } }); @@ -345,12 +328,12 @@ impl CoreLoop { match scan_result { Ok(()) => { let mut results: Vec = Vec::new(); - for (group_key, group_docs) in &binary_groups { + + for (group_key, state) in groups { let mut row = serde_json::Map::new(); - // Insert group-by field values into the result row. if !group_by.is_empty() - && let Ok(parts) = sonic_rs::from_str::>(group_key) + && let Ok(parts) = sonic_rs::from_str::>(&group_key) { for (i, field) in group_by.iter().enumerate() { let val = parts.get(i).cloned().unwrap_or(serde_json::Value::Null); @@ -358,32 +341,18 @@ impl CoreLoop { } } - let doc_slices: Vec<&[u8]> = group_docs.iter().map(|d| d.as_slice()).collect(); - - for agg in aggregates { - let val = msgpack_scan::compute_aggregate_binary( - &agg.function, - &agg.field, - agg.expr.as_ref(), - &doc_slices, - ); + for (alias, val) in state.finalize(aggregates) { let json_val: serde_json::Value = val.into(); - row.insert(agg.alias.clone(), json_val); + row.insert(alias, json_val); } - // Nested sub-aggregation on raw bytes. 
- if !sub_group_by.is_empty() && !sub_aggregates.is_empty() { - let mut sub_groups: std::collections::HashMap> = - std::collections::HashMap::new(); - for doc_bytes in &doc_slices { - let sub_key = msgpack_scan::build_group_key(doc_bytes, sub_group_by); - sub_groups.entry(sub_key).or_default().push(doc_bytes); - } - - let mut sub_results = Vec::new(); - for (sub_key, sub_docs) in &sub_groups { + if need_sub { + let sub_map = sub_groups.remove(&group_key).unwrap_or_default(); + let mut sub_results: Vec = Vec::new(); + for (sub_key, sub_state) in sub_map { let mut sub_row = serde_json::Map::new(); - if let Ok(parts) = sonic_rs::from_str::>(sub_key) + if let Ok(parts) = + sonic_rs::from_str::>(&sub_key) { for (i, field) in sub_group_by.iter().enumerate() { let val = @@ -391,15 +360,9 @@ impl CoreLoop { sub_row.insert(field.clone(), val); } } - for agg in sub_aggregates { - let val = msgpack_scan::compute_aggregate_binary( - &agg.function, - &agg.field, - agg.expr.as_ref(), - sub_docs, - ); + for (alias, val) in sub_state.finalize(sub_aggregates) { let json_val: serde_json::Value = val.into(); - sub_row.insert(agg.alias.clone(), json_val); + sub_row.insert(alias, json_val); } let mut sub_value = serde_json::Value::Object(sub_row); apply_user_aliases_to_rows( @@ -421,7 +384,11 @@ impl CoreLoop { let having_predicates: Vec = match zerompk::from_msgpack(having) { Ok(f) => f, Err(e) => { - tracing::warn!(core = self.core_id, error = %e, "HAVING predicate deserialization failed (schemaless)"); + tracing::warn!( + core = self.core_id, + error = %e, + "HAVING predicate deserialization failed (schemaless)" + ); Vec::new() } }; @@ -438,7 +405,6 @@ impl CoreLoop { match super::super::response_codec::encode_json_vec(&results) { Ok(payload) => { - // Cache the result for future identical queries. if filters.is_empty() && having.is_empty() { let cache_key = aggregate_cache_key( tid, @@ -448,7 +414,6 @@ impl CoreLoop { sub_group_by, sub_aggregates, ); - // Bounded cache: max 256 entries per core. if self.aggregate_cache.len() < 256 { self.aggregate_cache.insert(cache_key, payload.clone()); } diff --git a/nodedb/src/data/executor/handlers/columnar_write.rs b/nodedb/src/data/executor/handlers/columnar_write.rs index c6cc08da..4f31267f 100644 --- a/nodedb/src/data/executor/handlers/columnar_write.rs +++ b/nodedb/src/data/executor/handlers/columnar_write.rs @@ -98,6 +98,40 @@ impl CoreLoop { } } + // Flush memtable to a segment if the threshold has been reached. + if engine.should_flush() { + let new_segment_id = engine.next_segment_id(); + let (schema, columns, row_count) = engine.memtable_mut().drain_optimized(); + if row_count > 0 { + match nodedb_columnar::SegmentWriter::plain() + .write_segment(&schema, &columns, row_count) + { + Ok(bytes) => { + self.columnar_flushed_segments + .entry(collection.to_string()) + .or_default() + .push(bytes); + tracing::debug!( + core = self.core_id, + %collection, + new_segment_id, + row_count, + "columnar memtable flushed and segment bytes retained in memory" + ); + } + Err(e) => { + tracing::warn!( + core = self.core_id, + %collection, + error = %e, + "columnar segment encode failed; flushed rows may be lost" + ); + } + } + } + engine.on_memtable_flushed(new_segment_id); + } + // Populate R-tree for geometry columns so spatial predicates work. 
{ let tid = task.request.tenant_id; diff --git a/nodedb/src/data/executor/handlers/mod.rs b/nodedb/src/data/executor/handlers/mod.rs index 9cd56ab7..cfcd0fc8 100644 --- a/nodedb/src/data/executor/handlers/mod.rs +++ b/nodedb/src/data/executor/handlers/mod.rs @@ -1,3 +1,4 @@ +mod accum; pub mod aggregate; pub mod bulk_dml; pub mod columnar_agg; diff --git a/nodedb/src/data/executor/handlers/vector_search.rs b/nodedb/src/data/executor/handlers/vector_search.rs index 0c34619e..c2773df3 100644 --- a/nodedb/src/data/executor/handlers/vector_search.rs +++ b/nodedb/src/data/executor/handlers/vector_search.rs @@ -351,11 +351,14 @@ impl CoreLoop { } } +/// Maximum allowed ef_search value. Prevents DoS via unbounded beam width. +const MAX_EF_SEARCH: usize = 8192; + /// Compute effective ef parameter for HNSW search. fn effective_ef(ef_search: usize, top_k: usize) -> usize { if ef_search > 0 { - ef_search.max(top_k) + ef_search.max(top_k).min(MAX_EF_SEARCH) } else { - top_k.saturating_mul(4).max(64) + top_k.saturating_mul(4).clamp(64, MAX_EF_SEARCH) } } diff --git a/nodedb/src/data/executor/scan_normalize.rs b/nodedb/src/data/executor/scan_normalize.rs index 91f18d88..b54da46b 100644 --- a/nodedb/src/data/executor/scan_normalize.rs +++ b/nodedb/src/data/executor/scan_normalize.rs @@ -108,26 +108,83 @@ impl CoreLoop { }; let schema = engine.schema(); - let rows: Vec<_> = engine.scan_memtable_rows().take(limit).collect(); - let mut results = Vec::with_capacity(rows.len()); - - for row in rows { - // Build a nodedb_types::Value::Object directly — no JSON intermediary. - let mut map = std::collections::HashMap::new(); - let mut id = String::new(); - for (i, col_def) in schema.columns.iter().enumerate() { - if i < row.len() { - if col_def.name == "id" - && let nodedb_types::value::Value::String(s) = &row[i] - { - id.clone_from(s); + let mut results = Vec::new(); + + // 1. Read from flushed segments (older rows drained from prior memtable flushes). + if let Some(segments) = self.columnar_flushed_segments.get(collection) { + for seg_bytes in segments { + if results.len() >= limit { + break; + } + let reader = match nodedb_columnar::SegmentReader::open(seg_bytes) { + Ok(r) => r, + Err(e) => { + tracing::warn!(error = %e, "failed to open flushed columnar segment for scan"); + continue; + } + }; + let seg_row_count = reader.row_count() as usize; + let remaining = limit - results.len(); + let take = seg_row_count.min(remaining); + + // Decode all columns for this segment. 
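The `effective_ef` clamp added in vector_search.rs above bounds both the caller-supplied and the derived beam width. A standalone copy of the helper (constants and logic taken from the diff) with a few illustrative spot checks:

```rust
/// Maximum allowed ef_search value, as in the diff.
const MAX_EF_SEARCH: usize = 8192;

fn effective_ef(ef_search: usize, top_k: usize) -> usize {
    if ef_search > 0 {
        ef_search.max(top_k).min(MAX_EF_SEARCH)
    } else {
        top_k.saturating_mul(4).clamp(64, MAX_EF_SEARCH)
    }
}

fn main() {
    assert_eq!(effective_ef(0, 10), 64);            // default: 4 * top_k, floored at 64
    assert_eq!(effective_ef(0, 5_000), 8_192);      // default path is also capped
    assert_eq!(effective_ef(200, 10), 200);         // explicit ef wins when >= top_k
    assert_eq!(effective_ef(1_000_000, 10), 8_192); // hostile ef is clamped
}
```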
+ let col_count = schema.columns.len(); + let mut decoded_cols = Vec::with_capacity(col_count); + let mut decode_ok = true; + for col_idx in 0..col_count { + match reader.read_column(col_idx) { + Ok(dc) => decoded_cols.push(dc), + Err(e) => { + tracing::warn!(error = %e, col_idx, "failed to decode columnar segment column"); + decode_ok = false; + break; + } + } + } + if !decode_ok { + continue; + } + + for row_idx in 0..take { + let mut map = std::collections::HashMap::new(); + let mut id = String::new(); + for (col_idx, col_def) in schema.columns.iter().enumerate() { + let val = decoded_col_to_value(&decoded_cols[col_idx], row_idx); + if col_def.name == "id" + && let nodedb_types::value::Value::String(s) = &val + { + id.clone_from(s); + } + map.insert(col_def.name.clone(), val); } - map.insert(col_def.name.clone(), row[i].clone()); + let ndb_val = nodedb_types::value::Value::Object(map); + let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); + results.push((id, mp)); } } - let ndb_val = nodedb_types::value::Value::Object(map); - let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); - results.push((id, mp)); + } + + // 2. Read from the live memtable (most-recent rows not yet flushed). + if results.len() < limit { + let remaining = limit - results.len(); + let rows: Vec<_> = engine.scan_memtable_rows().take(remaining).collect(); + for row in rows { + let mut map = std::collections::HashMap::new(); + let mut id = String::new(); + for (i, col_def) in schema.columns.iter().enumerate() { + if i < row.len() { + if col_def.name == "id" + && let nodedb_types::value::Value::String(s) = &row[i] + { + id.clone_from(s); + } + map.insert(col_def.name.clone(), row[i].clone()); + } + } + let ndb_val = nodedb_types::value::Value::Object(map); + let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); + results.push((id, mp)); + } } results @@ -175,3 +232,84 @@ impl CoreLoop { } } } + +/// Convert a single row from a `DecodedColumn` to a `nodedb_types::value::Value`. +/// +/// Returns `Value::Null` if the row index is out of range or the validity bit is false. +fn decoded_col_to_value( + col: &nodedb_columnar::reader::DecodedColumn, + row_idx: usize, +) -> nodedb_types::value::Value { + use nodedb_columnar::reader::DecodedColumn; + use nodedb_types::value::Value; + + match col { + DecodedColumn::Int64 { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Integer(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Float64 { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Float(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Timestamp { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + // Represent as integer microseconds (same as Value::Integer for timestamps). + Value::Integer(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Bool { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Bool(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Binary { + data, + offsets, + valid, + } => { + if row_idx < valid.len() && valid[row_idx] && row_idx + 1 < offsets.len() { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + if start <= end && end <= data.len() { + let bytes = &data[start..end]; + // Best-effort UTF-8 interpretation; fall back to bytes. 
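The `Binary` arm being decoded here uses an offsets-plus-validity layout: N+1 offsets delimit N variable-length values inside one contiguous buffer, and the validity vector marks NULLs. A toy, crate-free illustration of that lookup (a simplified stand-in, not the `DecodedColumn` type itself):

```rust
/// Return the bytes of row `row`, or None for NULL / out-of-range / corrupt offsets.
fn binary_value<'a>(data: &'a [u8], offsets: &[u32], valid: &[bool], row: usize) -> Option<&'a [u8]> {
    if row >= valid.len() || !valid[row] || row + 1 >= offsets.len() {
        return None;
    }
    let (start, end) = (offsets[row] as usize, offsets[row + 1] as usize);
    (start <= end && end <= data.len()).then(|| &data[start..end])
}

fn main() {
    // Three rows: "ny", NULL, "tokyo", packed as "nytokyo".
    let data = b"nytokyo";
    let offsets = [0u32, 2, 2, 7];
    let valid = [true, false, true];
    assert_eq!(binary_value(data, &offsets, &valid, 0), Some(&b"ny"[..]));
    assert_eq!(binary_value(data, &offsets, &valid, 1), None); // NULL
    assert_eq!(binary_value(data, &offsets, &valid, 2), Some(&b"tokyo"[..]));
}
```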
+ match std::str::from_utf8(bytes) { + Ok(s) => Value::String(s.to_string()), + Err(_) => Value::Bytes(bytes.to_vec()), + } + } else { + Value::Null + } + } else { + Value::Null + } + } + DecodedColumn::DictEncoded { + ids, + dictionary, + valid, + } => { + if row_idx < valid.len() && valid[row_idx] { + let id = ids[row_idx] as usize; + if id < dictionary.len() { + Value::String(dictionary[id].clone()) + } else { + Value::Null + } + } else { + Value::Null + } + } + } +} diff --git a/nodedb/src/storage/segment.rs b/nodedb/src/storage/segment.rs index 2067a4bd..b4fee287 100644 --- a/nodedb/src/storage/segment.rs +++ b/nodedb/src/storage/segment.rs @@ -176,7 +176,7 @@ pub fn read_encrypted_segment( if let Some(key) = key { let mut aad = [0u8; nodedb_wal::record::HEADER_SIZE]; aad[..4].copy_from_slice(b"SEGM"); - key.decrypt(footer.min_lsn.as_u64(), &aad, data) + key.decrypt(key.epoch(), footer.min_lsn.as_u64(), &aad, data) .map_err(|e| crate::Error::Storage { engine: "segment".into(), detail: format!("segment decryption failed: {e}"), diff --git a/nodedb/tests/executor_tests/test_columnar_aggregate.rs b/nodedb/tests/executor_tests/test_columnar_aggregate.rs index c01a23bc..ebd56060 100644 --- a/nodedb/tests/executor_tests/test_columnar_aggregate.rs +++ b/nodedb/tests/executor_tests/test_columnar_aggregate.rs @@ -124,3 +124,128 @@ fn columnar_having_uses_canonical_key_but_output_keeps_user_alias() { assert_eq!(rows[0]["city_count"].as_u64(), Some(2)); assert!(rows[0].get("count(*)").is_none()); } + +#[test] +fn columnar_insert_triggers_memtable_flush() { + // Spec: after inserting more rows than DEFAULT_FLUSH_THRESHOLD (65536), the + // memtable must be drained to a segment on disk rather than accumulating + // unbounded memory. + let mut ctx = make_ctx(); + + // Build a batch of 70000 rows — above the 65536 flush threshold. + let rows: Vec = (0..70_000) + .map(|i| { + serde_json::json!({ + "id": format!("r{i}"), + "v": i, + }) + }) + .collect(); + let payload = nodedb_types::json_to_msgpack(&serde_json::Value::Array(rows)).unwrap(); + + // The write must succeed without error. Before the fix this would succeed + // but silently accumulate all rows in RAM; after the fix the engine flushes + // the memtable to a segment once the threshold is crossed. + send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Columnar(ColumnarOp::Insert { + collection: "large_col".into(), + payload, + format: "msgpack".into(), + }), + ); + + // All rows must be readable back — the segment flush must not lose data. + let doc_count = ctx + .core + .scan_collection(1, "large_col", 70_001) + .unwrap() + .len(); + assert_eq!( + doc_count, 70_000, + "all inserted rows must be scannable after flush" + ); +} + +#[test] +fn aggregate_group_by_does_not_require_full_materialization() { + // Spec: GROUP BY aggregation must return correct per-group results regardless + // of whether the implementation uses running aggregates (O(groups)) or + // full doc materialization (O(rows)). This test locks in correctness; + // the fix changes internal memory usage from O(N) to O(groups). + let mut ctx = make_ctx(); + + // Insert 1000 rows across 10 groups (g0..g9), each group gets 100 rows. 
+ let rows: Vec = (0..1_000) + .map(|i| { + serde_json::json!({ + "id": format!("r{i}"), + "g": format!("g{}", i % 10), + "v": i, + }) + }) + .collect(); + let payload = nodedb_types::json_to_msgpack(&serde_json::Value::Array(rows)).unwrap(); + + send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Columnar(ColumnarOp::Insert { + collection: "grouped".into(), + payload, + format: "msgpack".into(), + }), + ); + + let payload = send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Query(QueryOp::Aggregate { + collection: "grouped".into(), + group_by: vec!["g".into()], + aggregates: vec![ + AggregateSpec { + function: "count".into(), + alias: "count(*)".into(), + user_alias: None, + field: "*".into(), + expr: None, + }, + AggregateSpec { + function: "sum".into(), + alias: "sum(v)".into(), + user_alias: None, + field: "v".into(), + expr: None, + }, + ], + filters: Vec::new(), + having: Vec::new(), + limit: 100, + sub_group_by: Vec::new(), + sub_aggregates: Vec::new(), + }), + ); + + let result = payload_value(&payload); + let result_rows = result + .as_array() + .unwrap_or_else(|| panic!("expected aggregate rows, got {result}")); + + assert_eq!( + result_rows.len(), + 10, + "GROUP BY must produce exactly 10 groups" + ); + for row in result_rows { + assert_eq!( + row["count(*)"].as_u64(), + Some(100), + "each group must contain exactly 100 rows, got: {row}" + ); + } +} diff --git a/nodedb/tests/sql_prepared_statements.rs b/nodedb/tests/sql_prepared_statements.rs index ace3a7a2..44eadf75 100644 --- a/nodedb/tests/sql_prepared_statements.rs +++ b/nodedb/tests/sql_prepared_statements.rs @@ -23,3 +23,86 @@ async fn prepare_execute_deallocate_lifecycle() { server.exec("DEALLOCATE ALL").await.unwrap(); server.expect_error("EXECUTE q1", "does not exist").await; } + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_search_vector_dsl() { + let server = TestServer::start().await; + + // Create a document collection and a vector index on the embedding field. + server + .exec("CREATE COLLECTION vec_ep TYPE document") + .await + .unwrap(); + server + .exec("CREATE VECTOR INDEX idx_vec_ep ON vec_ep METRIC cosine DIM 3") + .await + .unwrap(); + + // Insert a document with an embedding vector. + server + .exec("INSERT INTO vec_ep (id, embedding) VALUES ('v1', ARRAY[1.0, 0.0, 0.0])") + .await + .unwrap(); + + // DSL SEARCH statements must not be rejected by the extended-protocol path + // with "Expected: an SQL statement". The statement should succeed and return + // results (or an empty result set — the key is no parse-time rejection). + let result = server + .query_text("SEARCH vec_ep USING VECTOR(embedding, ARRAY[1.0, 0.0, 0.0], 3)") + .await; + assert!( + result.is_ok(), + "SEARCH via extended protocol must not be rejected: {:?}", + result.err() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_upsert_dsl() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION upsert_ep").await.unwrap(); + + // UPSERT INTO DSL statements must not be rejected by the extended-protocol + // path with "Expected: an SQL statement". 
+ let result = server + .exec("UPSERT INTO upsert_ep { id: 'u1', name: 'alice' }") + .await; + assert!( + result.is_ok(), + "UPSERT INTO via extended protocol must not be rejected: {:?}", + result.err() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_select_strict_doc_returns_data() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION strict_ep TYPE DOCUMENT STRICT \ + (id TEXT PRIMARY KEY, name TEXT)", + ) + .await + .unwrap(); + server + .exec("INSERT INTO strict_ep (id, name) VALUES ('a', 'alice')") + .await + .unwrap(); + + // SELECT on a STRICT doc collection via the extended-query protocol must + // return the inserted row with actual column values, not null/empty columns. + let rows = server + .query_text("SELECT id, name FROM strict_ep WHERE id = 'a'") + .await + .unwrap(); + assert!(!rows.is_empty(), "SELECT should return the inserted row"); + + // Regression guard: the row must contain actual data, not null. + assert!( + rows[0].contains("alice"), + "extended protocol must not return null columns for STRICT doc, got: {:?}", + rows[0] + ); +}
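The three tests above exercise the DSL pass-through end to end. For reference, the prefix check they ultimately rely on (copied verbatim from the parser.rs change earlier in this diff) behaves like this when run standalone:

```rust
/// Copy of the DSL-prefix check from parser.rs in this diff.
fn is_dsl_statement(sql: &str) -> bool {
    let upper = sql.trim().to_uppercase();
    upper.starts_with("SEARCH ")
        || upper.starts_with("GRAPH ")
        || upper.starts_with("MATCH ")
        || upper.starts_with("OPTIONAL MATCH ")
        || upper.starts_with("CRDT MERGE ")
        || upper.starts_with("UPSERT INTO ")
        || upper.starts_with("CREATE VECTOR INDEX ")
        || upper.starts_with("CREATE FULLTEXT INDEX ")
        || upper.starts_with("CREATE SEARCH INDEX ")
        || upper.starts_with("CREATE SPARSE INDEX ")
}

fn main() {
    assert!(is_dsl_statement("  search products using vector(e, array[1.0], 5)"));
    assert!(is_dsl_statement("UPSERT INTO users { id: 'u1' }"));
    assert!(!is_dsl_statement("SELECT * FROM users")); // stays on the planned-SQL path
    assert!(!is_dsl_statement("SEARCHLIGHT"));         // prefixes require the trailing space
}
```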