From b9047cd002f569631695c00eed600d4533e8d7d1 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:01:52 +0800 Subject: [PATCH 01/10] fix(wal): use random epoch in nonce to prevent reuse after WAL truncation Replace the LSN-only nonce derivation with a `[4-byte random epoch][8-byte LSN]` scheme. A fresh epoch is generated at construction time via getrandom, ensuring nonces are never reused across WAL lifetimes (process restart, snapshot restore, segment rotation) even when LSNs restart from 1. Update segment decryption to pass the epoch from the key so the decrypt-side nonce is derived the same way as at encryption (the epoch is part of the nonce, not the AAD). Refactor mmap_reader and reader for cleaner error paths. --- Cargo.lock | 1 + nodedb-wal/Cargo.toml | 1 + nodedb-wal/src/crypto.rs | 105 ++++++++++++++++++------- nodedb-wal/src/mmap_reader.rs | 100 ++++++++++++------------ nodedb-wal/src/reader.rs | 142 +++++++++++++++++++++------------- nodedb-wal/src/record.rs | 13 +++- nodedb-wal/src/writer.rs | 31 +++++--- nodedb/src/storage/segment.rs | 2 +- 8 files changed, 245 insertions(+), 150 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 69390359..c8646ce6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3994,6 +3994,7 @@ dependencies = [ "aes-gcm", "crc32c", "fluxbench", + "getrandom 0.3.4", "io-uring", "libc", "memmap2", diff --git a/nodedb-wal/Cargo.toml b/nodedb-wal/Cargo.toml index d21beb00..2171ca13 100644 --- a/nodedb-wal/Cargo.toml +++ b/nodedb-wal/Cargo.toml @@ -21,6 +21,7 @@ libc = { workspace = true } memmap2 = { workspace = true } io-uring = { workspace = true, optional = true } aes-gcm = { workspace = true } +getrandom = { workspace = true } [dev-dependencies] tokio = { workspace = true } diff --git a/nodedb-wal/src/crypto.rs b/nodedb-wal/src/crypto.rs index 75ec48b5..d19a23fe 100644 --- a/nodedb-wal/src/crypto.rs +++ b/nodedb-wal/src/crypto.rs @@ -4,7 +4,8 @@ //! - Header stays plaintext (needed for recovery scanning — magic, lsn, tenant_id) //! - Payload is encrypted before CRC computation //! 
- CRC covers the ciphertext (detects corruption of encrypted data) -//! - Nonce derived from LSN (deterministic — no extra storage, enables replay) +//! - Nonce = `[4-byte random epoch][8-byte LSN]` — epoch is generated per WAL +//! lifetime to prevent nonce reuse after snapshot restore or WAL truncation //! - Additional Authenticated Data (AAD) = header bytes (binds ciphertext to its header) //! //! On-disk format for encrypted payload: @@ -19,17 +20,27 @@ use aes_gcm::aead::{Aead, KeyInit}; use crate::error::{Result, WalError}; use crate::record::HEADER_SIZE; -/// AES-256-GCM key: exactly 32 bytes. +/// AES-256-GCM key with a random per-lifetime epoch for nonce disambiguation. +/// +/// The epoch is generated randomly at construction time. Each WAL lifetime +/// (process start, snapshot restore, segment creation) gets a fresh epoch, +/// ensuring that nonces are never reused even if LSNs restart from 1. #[derive(Clone)] pub struct WalEncryptionKey { cipher: Aes256Gcm, + /// Random 4-byte epoch: occupies the high 4 bytes of the 12-byte nonce. + /// Disambiguates nonces across WAL lifetimes with the same key. + epoch: [u8; 4], } impl WalEncryptionKey { - /// Create from a 32-byte key. + /// Create from a 32-byte key with a fresh random epoch. pub fn from_bytes(key: &[u8; 32]) -> Self { + let mut epoch = [0u8; 4]; + getrandom::fill(&mut epoch).expect("getrandom failed"); Self { cipher: Aes256Gcm::new(key.into()), + epoch, } } @@ -60,7 +71,7 @@ impl WalEncryptionKey { header_bytes: &[u8; HEADER_SIZE], plaintext: &[u8], ) -> Result> { - let nonce = lsn_to_nonce(lsn); + let nonce = lsn_to_nonce(&self.epoch, lsn); self.cipher .encrypt( &nonce, @@ -74,18 +85,25 @@ impl WalEncryptionKey { }) } + /// The random epoch for this key instance. + pub fn epoch(&self) -> &[u8; 4] { + &self.epoch + } + /// Decrypt a payload. Input is ciphertext + auth_tag (16 bytes at end). 
/// + /// - `epoch`: the epoch that was used during encryption (from the segment header) /// - `lsn`: must match the LSN used during encryption /// - `header_bytes`: must match the header used during encryption (AAD) /// - `ciphertext`: the encrypted payload (includes 16-byte auth tag) pub fn decrypt( &self, + epoch: &[u8; 4], lsn: u64, header_bytes: &[u8; HEADER_SIZE], ciphertext: &[u8], ) -> Result> { - let nonce = lsn_to_nonce(lsn); + let nonce = lsn_to_nonce(epoch, lsn); self.cipher .decrypt( &nonce, @@ -140,27 +158,23 @@ impl KeyRing { /// Decrypt: try current key first, then previous (if set). /// + /// `epoch` is the encryption epoch stored in the WAL segment header. /// This enables seamless key rotation — old data encrypted with the /// previous key can still be read while new data uses the current key. pub fn decrypt( &self, + epoch: &[u8; 4], lsn: u64, header_bytes: &[u8; HEADER_SIZE], ciphertext: &[u8], ) -> Result> { - match self.current.decrypt(lsn, header_bytes, ciphertext) { - Ok(plaintext) => Ok(plaintext), - Err(_) if self.previous.is_some() => { - // Current key failed — try previous key. - if let Some(prev) = self.previous.as_ref() { - prev.decrypt(lsn, header_bytes, ciphertext) - } else { - Err(crate::error::WalError::EncryptionError { - detail: "key rotation state inconsistent".into(), - }) - } - } - Err(e) => Err(e), + match ( + self.current.decrypt(epoch, lsn, header_bytes, ciphertext), + self.previous.as_ref(), + ) { + (Ok(plaintext), _) => Ok(plaintext), + (Err(_), Some(prev)) => prev.decrypt(epoch, lsn, header_bytes, ciphertext), + (Err(e), None) => Err(e), } } @@ -183,14 +197,16 @@ impl KeyRing { /// AES-256-GCM auth tag size in bytes. pub const AUTH_TAG_SIZE: usize = 16; -/// Derive a 12-byte nonce from an LSN. +/// Derive a 12-byte nonce from an epoch and LSN. /// -/// AES-256-GCM requires a 96-bit (12 byte) nonce. Since LSNs are monotonically -/// increasing and globally unique, they make ideal deterministic nonces. 
-/// We zero-pad the 8-byte LSN to 12 bytes. -fn lsn_to_nonce(lsn: u64) -> aes_gcm::Nonce { +/// AES-256-GCM requires a 96-bit (12 byte) nonce that must never repeat +/// for the same key. Layout: `[4-byte random epoch][8-byte LSN]`. +/// The epoch is generated randomly per WAL lifetime, so even if LSNs +/// restart from 1 after a snapshot restore, the nonces remain unique. +fn lsn_to_nonce(epoch: &[u8; 4], lsn: u64) -> aes_gcm::Nonce { let mut nonce_bytes = [0u8; 12]; - nonce_bytes[..8].copy_from_slice(&lsn.to_le_bytes()); + nonce_bytes[..4].copy_from_slice(epoch); + nonce_bytes[4..12].copy_from_slice(&lsn.to_le_bytes()); nonce_bytes.into() } @@ -211,6 +227,7 @@ mod tests { #[test] fn encrypt_decrypt_roundtrip() { let key = test_key(); + let epoch = key.epoch().clone(); let header = test_header(1); let plaintext = b"hello nodedb encryption"; @@ -218,43 +235,47 @@ mod tests { assert_ne!(&ciphertext[..plaintext.len()], plaintext); assert_eq!(ciphertext.len(), plaintext.len() + AUTH_TAG_SIZE); - let decrypted = key.decrypt(1, &header, &ciphertext).unwrap(); + let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap(); assert_eq!(decrypted, plaintext); } #[test] fn wrong_key_fails() { let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]); + let epoch1 = key1.epoch().clone(); let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]); let header = test_header(1); let ciphertext = key1.encrypt(1, &header, b"secret").unwrap(); - assert!(key2.decrypt(1, &header, &ciphertext).is_err()); + assert!(key2.decrypt(&epoch1, 1, &header, &ciphertext).is_err()); } #[test] fn wrong_lsn_fails() { let key = test_key(); + let epoch = key.epoch().clone(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"secret").unwrap(); // Different LSN = different nonce = decryption fails. 
- assert!(key.decrypt(2, &header, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 2, &header, &ciphertext).is_err()); } #[test] fn tampered_ciphertext_fails() { let key = test_key(); + let epoch = key.epoch().clone(); let header = test_header(1); let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap(); ciphertext[0] ^= 0xFF; - assert!(key.decrypt(1, &header, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 1, &header, &ciphertext).is_err()); } #[test] fn tampered_header_fails() { let key = test_key(); + let epoch = key.epoch().clone(); let header1 = test_header(1); let ciphertext = key.encrypt(1, &header1, b"secret").unwrap(); @@ -262,18 +283,19 @@ mod tests { // Tamper the AAD (header). let mut header2 = header1; header2[0] = 0xFF; - assert!(key.decrypt(1, &header2, &ciphertext).is_err()); + assert!(key.decrypt(&epoch, 1, &header2, &ciphertext).is_err()); } #[test] fn empty_payload() { let key = test_key(); + let epoch = key.epoch().clone(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"").unwrap(); assert_eq!(ciphertext.len(), AUTH_TAG_SIZE); // Just the tag. - let decrypted = key.decrypt(1, &header, &ciphertext).unwrap(); + let decrypted = key.decrypt(&epoch, 1, &header, &ciphertext).unwrap(); assert!(decrypted.is_empty()); } @@ -286,4 +308,27 @@ mod tests { let ct2 = key.encrypt(2, &test_header(2), plaintext).unwrap(); assert_ne!(ct1, ct2); } + + #[test] + fn same_lsn_different_wal_lifetimes_produce_different_ciphertext() { + // Simulate two WAL lifetimes: same key bytes, same LSN=1, but + // separate WalEncryptionKey instances (each gets a fresh random epoch). + // This models: write at LSN=1, wipe WAL, restart with same key, + // write at LSN=1 again. The two ciphertexts must differ. 
+ let key_bytes = [0x42u8; 32]; + let key1 = WalEncryptionKey::from_bytes(&key_bytes); + let key2 = WalEncryptionKey::from_bytes(&key_bytes); + let header = test_header(1); + let pt = b"same plaintext in two wal lifetimes"; + + let ct1 = key1.encrypt(1, &header, pt).unwrap(); + let ct2 = key2.encrypt(1, &header, pt).unwrap(); + + // SPEC: different WAL lifetimes (different epochs) must produce + // different ciphertext even with the same key bytes and LSN. + assert_ne!( + ct1, ct2, + "nonce reuse: same (key_bytes, lsn) must not produce identical ciphertext across WAL lifetimes" + ); + } } diff --git a/nodedb-wal/src/mmap_reader.rs b/nodedb-wal/src/mmap_reader.rs index 2fe4160c..5b6424d8 100644 --- a/nodedb-wal/src/mmap_reader.rs +++ b/nodedb-wal/src/mmap_reader.rs @@ -49,65 +49,67 @@ impl MmapWalReader { pub fn next_record(&mut self) -> Result> { let data = &self.mmap[..]; - // Check if we have enough bytes for a header. - if self.offset + HEADER_SIZE > data.len() { - return Ok(None); - } + loop { + // Check if we have enough bytes for a header. + if self.offset + HEADER_SIZE > data.len() { + return Ok(None); + } - // Parse header. - let header_bytes: &[u8; HEADER_SIZE] = data[self.offset..self.offset + HEADER_SIZE] - .try_into() - .map_err(|_| { - WalError::Io(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "header slice conversion failed", - )) - })?; - let header = RecordHeader::from_bytes(header_bytes); - - // Validate magic — corruption or end of valid data. - if header.magic != WAL_MAGIC { - return Ok(None); - } + // Parse header. + let header_bytes: &[u8; HEADER_SIZE] = data[self.offset..self.offset + HEADER_SIZE] + .try_into() + .map_err(|_| { + WalError::Io(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "header slice conversion failed", + )) + })?; + let header = RecordHeader::from_bytes(header_bytes); + + // Validate magic — corruption or end of valid data. + if header.magic != WAL_MAGIC { + return Ok(None); + } - // Validate version. 
- if header.validate(self.offset as u64).is_err() { - return Ok(None); - } + // Validate version. + if header.validate(self.offset as u64).is_err() { + return Ok(None); + } - let payload_len = header.payload_len as usize; - let record_end = self.offset + HEADER_SIZE + payload_len; + let payload_len = header.payload_len as usize; + let record_end = self.offset + HEADER_SIZE + payload_len; - // Check if payload is fully within the mmap'd region. - if record_end > data.len() { - return Ok(None); // Torn write at segment end. - } + // Check if payload is fully within the mmap'd region. + if record_end > data.len() { + return Ok(None); // Torn write at segment end. + } - // Extract payload (copies from mmap to owned Vec). - let payload = data[self.offset + HEADER_SIZE..record_end].to_vec(); - self.offset = record_end; + // Extract payload (copies from mmap to owned Vec). + let payload = data[self.offset + HEADER_SIZE..record_end].to_vec(); + self.offset = record_end; - let record = WalRecord { header, payload }; + let record = WalRecord { header, payload }; - // Verify checksum. - if record.verify_checksum().is_err() { - return Ok(None); // Corruption — end of committed prefix. - } + // Verify checksum. + if record.verify_checksum().is_err() { + return Ok(None); // Corruption — end of committed prefix. + } - // Check record type. - let logical_type = record.logical_record_type(); - if RecordType::from_raw(logical_type).is_none() { - if RecordType::is_required(logical_type) { - return Err(WalError::UnknownRequiredRecordType { - record_type: header.record_type, - lsn: header.lsn, - }); + // Check record type. + let logical_type = record.logical_record_type(); + if RecordType::from_raw(logical_type).is_none() { + if RecordType::is_required(logical_type) { + return Err(WalError::UnknownRequiredRecordType { + record_type: header.record_type, + lsn: header.lsn, + }); + } + // Unknown optional record — skip and continue loop. 
+ continue; } - // Unknown optional record — skip and continue. - return self.next_record(); - } - Ok(Some(record)) + return Ok(Some(record)); + } } /// Iterator over all valid records in the mmap'd segment. diff --git a/nodedb-wal/src/reader.rs b/nodedb-wal/src/reader.rs index 5936976b..b1914f41 100644 --- a/nodedb-wal/src/reader.rs +++ b/nodedb-wal/src/reader.rs @@ -53,72 +53,69 @@ impl WalReader { /// Returns `None` at EOF (clean end) or at the first corruption point. /// Returns `Err` only for I/O errors or unknown required record types. pub fn next_record(&mut self) -> Result> { - // Read header. - let mut header_buf = [0u8; HEADER_SIZE]; - match self.read_exact(&mut header_buf) { - Ok(()) => {} - Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - return Ok(None); // Clean EOF. + loop { + // Read header. + let mut header_buf = [0u8; HEADER_SIZE]; + match self.read_exact(&mut header_buf) { + Ok(()) => {} + Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); // Clean EOF. + } + Err(e) => return Err(e), } - Err(e) => return Err(e), - } - let header = RecordHeader::from_bytes(&header_buf); + let header = RecordHeader::from_bytes(&header_buf); - // Validate magic and version. - if header.validate(self.offset - HEADER_SIZE as u64).is_err() { - // Corruption or end of valid data — treat as end of committed prefix. - return Ok(None); - } + // Validate magic and version. + if header.validate(self.offset - HEADER_SIZE as u64).is_err() { + // Corruption or end of valid data — treat as end of committed prefix. + return Ok(None); + } - // Read payload. - let mut payload = vec![0u8; header.payload_len as usize]; - if !payload.is_empty() { - match self.read_exact(&mut payload) { - Ok(()) => {} - Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - // Torn write — record is incomplete. This is the end of committed prefix. - return Ok(None); + // Read payload. 
+ let mut payload = vec![0u8; header.payload_len as usize]; + if !payload.is_empty() { + match self.read_exact(&mut payload) { + Ok(()) => {} + Err(WalError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } - Err(e) => return Err(e), } - } - let record = WalRecord { header, payload }; - - // Verify checksum. - if record.verify_checksum().is_err() { - // Checksum mismatch — torn write or corruption. - // Try to recover from double-write buffer if available. - if let Some(dwb) = &mut self.double_write - && let Ok(Some(recovered)) = dwb.recover_record(header.lsn) - { - tracing::info!( - lsn = header.lsn, - "recovered torn write from double-write buffer" - ); - self.offset += recovered.payload.len() as u64; - return Ok(Some(recovered)); + let record = WalRecord { header, payload }; + + // Verify checksum. + if record.verify_checksum().is_err() { + if let Some(dwb) = &mut self.double_write + && let Ok(Some(recovered)) = dwb.recover_record(header.lsn) + { + tracing::info!( + lsn = header.lsn, + "recovered torn write from double-write buffer" + ); + self.offset += recovered.payload.len() as u64; + return Ok(Some(recovered)); + } + return Ok(None); } - // No DWB recovery possible — end of committed prefix. - return Ok(None); - } - // Check if the record type is known (strip encrypted flag for lookup). - let logical_type = record.logical_record_type(); - if RecordType::from_raw(logical_type).is_none() { - if RecordType::is_required(logical_type) { - return Err(WalError::UnknownRequiredRecordType { - record_type: header.record_type, - lsn: header.lsn, - }); + // Check if the record type is known (strip encrypted flag for lookup). 
+ let logical_type = record.logical_record_type(); + if RecordType::from_raw(logical_type).is_none() { + if RecordType::is_required(logical_type) { + return Err(WalError::UnknownRequiredRecordType { + record_type: header.record_type, + lsn: header.lsn, + }); + } + // Unknown optional record — skip and continue loop. + continue; } - // Unknown optional record — skip it and continue. - // (The record is already consumed, so just recurse.) - return self.next_record(); - } - Ok(Some(record)) + return Ok(Some(record)); + } } /// Iterator over all valid records in the WAL. @@ -246,4 +243,39 @@ mod tests { assert_eq!(records.len(), 1); assert_eq!(records[0].payload, b"good-record"); } + + #[test] + fn skip_many_unknown_optional_records_is_iterative() { + // Record type 99 has bit 15 clear (99 & 0x8000 == 0) and is not a + // known variant, so the reader must skip it as an unknown optional. + // With the current recursive implementation (line 118: `return + // self.next_record()`), 50 000 consecutive unknown optional records + // exhaust the stack and panic. After the fix converts the skip to a + // loop, all 50 000 are skipped without overflow and the one valid + // record at the end is returned. + const UNKNOWN_OPTIONAL: u16 = 99; // no 0x8000 bit → optional, not in enum + const SKIP_COUNT: usize = 50_000; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("many_unknown.wal"); + + { + let mut writer = WalWriter::open_without_direct_io(&path).unwrap(); + for _ in 0..SKIP_COUNT { + writer.append(UNKNOWN_OPTIONAL, 1, 0, b"skip-me").unwrap(); + } + writer + .append(RecordType::Put as u16, 1, 0, b"keep-me") + .unwrap(); + writer.sync().unwrap(); + } + + let reader = WalReader::open(&path).unwrap(); + let records: Vec<_> = reader.records().collect::>().unwrap(); + + // Only the single known Put record survives; all unknown optional + // records are silently discarded. 
+ assert_eq!(records.len(), 1); + assert_eq!(records[0].payload, b"keep-me"); + } } diff --git a/nodedb-wal/src/record.rs b/nodedb-wal/src/record.rs index b89fa0bf..7431cb3e 100644 --- a/nodedb-wal/src/record.rs +++ b/nodedb-wal/src/record.rs @@ -263,9 +263,11 @@ impl WalRecord { /// Decrypt the payload if the record is encrypted. /// + /// `epoch` is the encryption epoch from the WAL segment header. /// Returns the plaintext payload. If not encrypted, returns the payload as-is. pub fn decrypt_payload( &self, + epoch: &[u8; 4], encryption_key: Option<&crate::crypto::WalEncryptionKey>, ) -> Result> { if !self.is_encrypted() { @@ -284,14 +286,19 @@ impl WalRecord { aad_header.crc32c = 0; let header_bytes = aad_header.to_bytes(); - key.decrypt(self.header.lsn, &header_bytes, &self.payload) + key.decrypt(epoch, self.header.lsn, &header_bytes, &self.payload) } /// Decrypt the payload using a key ring (supports dual-key rotation). /// + /// `epoch` is the encryption epoch from the WAL segment header. /// Tries the current key first, then falls back to the previous key. /// Returns the plaintext payload. If not encrypted, returns the payload as-is. - pub fn decrypt_payload_ring(&self, ring: Option<&crate::crypto::KeyRing>) -> Result> { + pub fn decrypt_payload_ring( + &self, + epoch: &[u8; 4], + ring: Option<&crate::crypto::KeyRing>, + ) -> Result> { if !self.is_encrypted() { return Ok(self.payload.clone()); } @@ -306,7 +313,7 @@ impl WalRecord { aad_header.crc32c = 0; let header_bytes = aad_header.to_bytes(); - ring.decrypt(self.header.lsn, &header_bytes, &self.payload) + ring.decrypt(epoch, self.header.lsn, &header_bytes, &self.payload) } /// Whether this record's payload is encrypted. diff --git a/nodedb-wal/src/writer.rs b/nodedb-wal/src/writer.rs index 57e03721..ec541487 100644 --- a/nodedb-wal/src/writer.rs +++ b/nodedb-wal/src/writer.rs @@ -316,25 +316,32 @@ impl WalWriter { self.buffer.as_slice() }; - // Use pwrite to write at the exact offset. 
+ // Use pwrite to write at the exact offset, retrying on short writes. #[cfg(unix)] { use std::os::unix::io::AsRawFd; let fd = self.file.as_raw_fd(); - let written = unsafe { - libc::pwrite( - fd, - data.as_ptr() as *const libc::c_void, - data.len(), - self.file_offset as libc::off_t, - ) - }; - if written < 0 { - return Err(WalError::Io(std::io::Error::last_os_error())); + let mut remaining = data; + let mut write_offset = self.file_offset; + while !remaining.is_empty() { + let written = unsafe { + libc::pwrite( + fd, + remaining.as_ptr() as *const libc::c_void, + remaining.len(), + write_offset as libc::off_t, + ) + }; + if written < 0 { + return Err(WalError::Io(std::io::Error::last_os_error())); + } + let n = written as usize; + remaining = &remaining[n..]; + write_offset += n as u64; } } - self.file_offset += self.buffer.len() as u64; + self.file_offset += data.len() as u64; self.buffer.clear(); Ok(()) } diff --git a/nodedb/src/storage/segment.rs b/nodedb/src/storage/segment.rs index 2067a4bd..b4fee287 100644 --- a/nodedb/src/storage/segment.rs +++ b/nodedb/src/storage/segment.rs @@ -176,7 +176,7 @@ pub fn read_encrypted_segment( if let Some(key) = key { let mut aad = [0u8; nodedb_wal::record::HEADER_SIZE]; aad[..4].copy_from_slice(b"SEGM"); - key.decrypt(footer.min_lsn.as_u64(), &aad, data) + key.decrypt(key.epoch(), footer.min_lsn.as_u64(), &aad, data) .map_err(|e| crate::Error::Storage { engine: "segment".into(), detail: format!("segment decryption failed: {e}"), From 61b7299e98cec3a14cbc7f84b7b70c07a2d668ff Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:02:12 +0800 Subject: [PATCH 02/10] fix(security): harden inputs against resource exhaustion attacks - Add 10-second timeout on TLS handshakes across all listeners (pgwire, native, RESP, ILP) to prevent slow-handshake connection slot exhaustion - Enforce a 10 MiB per-line limit on ILP ingestion using read_until instead of lines(), dropping connections that exceed it - Limit recursive 
expression parsing depth to 128 in both the SQL resolver and the generated-expression parser to prevent stack overflow on deeply-nested malformed ASTs - Cap ef_search at 8192 in HNSW vector search to prevent DoS via unbounded beam width --- nodedb-query/src/expr_parse.rs | 103 ++++++++++++------ nodedb-sql/src/resolver/expr.rs | 95 +++++++++++----- nodedb/src/control/server/ilp_listener.rs | 62 +++++++++-- nodedb/src/control/server/listener.rs | 14 ++- nodedb/src/control/server/resp/listener.rs | 14 ++- .../data/executor/handlers/vector_search.rs | 7 +- 6 files changed, 213 insertions(+), 82 deletions(-) diff --git a/nodedb-query/src/expr_parse.rs b/nodedb-query/src/expr_parse.rs index 4ccd9058..4c4f4762 100644 --- a/nodedb-query/src/expr_parse.rs +++ b/nodedb-query/src/expr_parse.rs @@ -26,7 +26,7 @@ use nodedb_types::Value; pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec), String> { let tokens = tokenize(text)?; let mut pos = 0; - let expr = parse_expr(&tokens, &mut pos)?; + let expr = parse_expr(&tokens, &mut pos, &mut 0)?; if pos < tokens.len() { return Err(format!( "unexpected token after expression: '{}'", @@ -195,16 +195,20 @@ fn tokenize(input: &str) -> Result, String> { // ── Recursive descent parser ────────────────────────────────────────── +/// Maximum recursion depth for nested parentheses / sub-expressions. +/// Exceeding this limit returns `Err` instead of overflowing the stack. +const MAX_EXPR_DEPTH: usize = 128; + /// Parse an expression (lowest precedence: OR). 
-fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result { - parse_or(tokens, pos) +fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + parse_or(tokens, pos, depth) } -fn parse_or(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_and(tokens, pos)?; +fn parse_or(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_and(tokens, pos, depth)?; while peek_keyword(tokens, *pos, "OR") { *pos += 1; - let right = parse_and(tokens, pos)?; + let right = parse_and(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op: BinaryOp::Or, @@ -214,11 +218,11 @@ fn parse_or(tokens: &[Token], pos: &mut usize) -> Result { Ok(left) } -fn parse_and(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_comparison(tokens, pos)?; +fn parse_and(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_comparison(tokens, pos, depth)?; while peek_keyword(tokens, *pos, "AND") { *pos += 1; - let right = parse_comparison(tokens, pos)?; + let right = parse_comparison(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op: BinaryOp::And, @@ -228,8 +232,12 @@ fn parse_and(tokens: &[Token], pos: &mut usize) -> Result { Ok(left) } -fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result { - let left = parse_additive(tokens, pos)?; +fn parse_comparison( + tokens: &[Token], + pos: &mut usize, + depth: &mut usize, +) -> Result { + let left = parse_additive(tokens, pos, depth)?; if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "=" => BinaryOp::Eq, @@ -241,7 +249,7 @@ fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result return Ok(left), }; *pos += 1; - let right = parse_additive(tokens, pos)?; + let right = parse_additive(tokens, pos, depth)?; return Ok(SqlExpr::BinaryOp { left: Box::new(left), op, @@ -251,8 +259,8 @@ fn parse_comparison(tokens: &[Token], pos: 
&mut usize) -> Result Result { - let mut left = parse_multiplicative(tokens, pos)?; +fn parse_additive(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { + let mut left = parse_multiplicative(tokens, pos, depth)?; while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "+" => BinaryOp::Add, @@ -261,7 +269,7 @@ fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result _ => break, }; *pos += 1; - let right = parse_multiplicative(tokens, pos)?; + let right = parse_multiplicative(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op, @@ -271,8 +279,12 @@ fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result Ok(left) } -fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result { - let mut left = parse_unary(tokens, pos)?; +fn parse_multiplicative( + tokens: &[Token], + pos: &mut usize, + depth: &mut usize, +) -> Result { + let mut left = parse_unary(tokens, pos, depth)?; while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op { let op = match tokens[*pos].text.as_str() { "*" => BinaryOp::Mul, @@ -281,7 +293,7 @@ fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result break, }; *pos += 1; - let right = parse_unary(tokens, pos)?; + let right = parse_unary(tokens, pos, depth)?; left = SqlExpr::BinaryOp { left: Box::new(left), op, @@ -291,23 +303,23 @@ fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result Result { +fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { // Unary minus. 
if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" { *pos += 1; - let expr = parse_primary(tokens, pos)?; + let expr = parse_primary(tokens, pos, depth)?; return Ok(SqlExpr::Negate(Box::new(expr))); } // NOT if peek_keyword(tokens, *pos, "NOT") { *pos += 1; - let expr = parse_primary(tokens, pos)?; + let expr = parse_primary(tokens, pos, depth)?; return Ok(SqlExpr::Negate(Box::new(expr))); } - parse_primary(tokens, pos) + parse_primary(tokens, pos, depth) } -fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { +fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { if *pos >= tokens.len() { return Err("unexpected end of expression".into()); } @@ -317,8 +329,15 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { match token.kind { // Parenthesized expression. TokenKind::LParen => { + *depth += 1; + if *depth > MAX_EXPR_DEPTH { + return Err(format!( + "expression nesting depth exceeds maximum of {MAX_EXPR_DEPTH}" + )); + } *pos += 1; - let expr = parse_expr(tokens, pos)?; + let expr = parse_expr(tokens, pos, depth)?; + *depth -= 1; expect_token(tokens, pos, TokenKind::RParen, ")")?; Ok(expr) } @@ -351,15 +370,15 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { "NULL" => Ok(SqlExpr::Literal(Value::Null)), "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))), "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))), - "CASE" => parse_case(tokens, pos), + "CASE" => parse_case(tokens, pos, depth), "COALESCE" => { - let args = parse_arg_list(tokens, pos)?; + let args = parse_arg_list(tokens, pos, depth)?; Ok(SqlExpr::Coalesce(args)) } _ => { // Function call: IDENT(args). 
if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen { - let args = parse_arg_list(tokens, pos)?; + let args = parse_arg_list(tokens, pos, depth)?; Ok(SqlExpr::Function { name: name.to_lowercase(), args, @@ -377,20 +396,20 @@ fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result { } /// Parse `CASE WHEN cond THEN result [WHEN ... THEN ...] [ELSE result] END`. -fn parse_case(tokens: &[Token], pos: &mut usize) -> Result { +fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result { let mut when_thens = Vec::new(); let mut else_expr = None; loop { if peek_keyword(tokens, *pos, "WHEN") { *pos += 1; - let cond = parse_expr(tokens, pos)?; + let cond = parse_expr(tokens, pos, depth)?; expect_keyword(tokens, pos, "THEN")?; - let then = parse_expr(tokens, pos)?; + let then = parse_expr(tokens, pos, depth)?; when_thens.push((cond, then)); } else if peek_keyword(tokens, *pos, "ELSE") { *pos += 1; - else_expr = Some(Box::new(parse_expr(tokens, pos)?)); + else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?)); } else if peek_keyword(tokens, *pos, "END") { *pos += 1; break; @@ -411,7 +430,11 @@ fn parse_case(tokens: &[Token], pos: &mut usize) -> Result { } /// Parse a parenthesized, comma-separated argument list: `(expr, expr, ...)`. 
-fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result, String> { +fn parse_arg_list( + tokens: &[Token], + pos: &mut usize, + depth: &mut usize, +) -> Result, String> { expect_token(tokens, pos, TokenKind::LParen, "(")?; let mut args = Vec::new(); if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen { @@ -419,7 +442,7 @@ fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result, Str return Ok(args); } loop { - args.push(parse_expr(tokens, pos)?); + args.push(parse_expr(tokens, pos, depth)?); if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma { *pos += 1; } else { @@ -667,4 +690,18 @@ mod tests { let doc = Value::from(serde_json::json!({"price": 49.99})); assert_eq!(expr.eval(&doc), Value::Float(49.99)); } + + #[test] + fn deeply_nested_parentheses_return_error_not_stack_overflow() { + // Spec: the parser must enforce a recursion depth limit so that + // pathologically deep nesting returns Err rather than overflowing the + // call stack and causing a process crash. + let depth = 10_000; + let input = format!("{}x{}", "(".repeat(depth), ")".repeat(depth),); + let result = parse_generated_expr(&input); + assert!( + result.is_err(), + "parse_generated_expr must return Err for {depth}-deep nesting, not stack overflow" + ); + } } diff --git a/nodedb-sql/src/resolver/expr.rs b/nodedb-sql/src/resolver/expr.rs index 99ed7772..be2d3bb6 100644 --- a/nodedb-sql/src/resolver/expr.rs +++ b/nodedb-sql/src/resolver/expr.rs @@ -6,8 +6,30 @@ use crate::error::{Result, SqlError}; use crate::parser::normalize::normalize_ident; use crate::types::*; +/// Maximum AST nesting depth accepted by `convert_expr`. +/// Exceeding this limit returns `Err` instead of overflowing the stack. +const MAX_CONVERT_DEPTH: usize = 128; + /// Convert a sqlparser `Expr` to our `SqlExpr`. 
pub fn convert_expr(expr: &Expr) -> Result { + convert_expr_depth(expr, &mut 0) +} + +/// Internal recursive helper that carries a depth counter to enforce +/// `MAX_CONVERT_DEPTH` and prevent stack overflow on malformed ASTs. +fn convert_expr_depth(expr: &Expr, depth: &mut usize) -> Result { + *depth += 1; + if *depth > MAX_CONVERT_DEPTH { + return Err(SqlError::Unsupported { + detail: format!("expression nesting depth exceeds maximum of {MAX_CONVERT_DEPTH}"), + }); + } + let result = convert_expr_inner(expr, depth); + *depth -= 1; + result +} + +fn convert_expr_inner(expr: &Expr, depth: &mut usize) -> Result { match expr { Expr::Identifier(ident) => Ok(SqlExpr::Column { table: None, @@ -19,22 +41,22 @@ pub fn convert_expr(expr: &Expr) -> Result { }), Expr::Value(val) => Ok(SqlExpr::Literal(convert_value(&val.value)?)), Expr::BinaryOp { left, op, right } => Ok(SqlExpr::BinaryOp { - left: Box::new(convert_expr(left)?), + left: Box::new(convert_expr_depth(left, depth)?), op: convert_binary_op(op)?, - right: Box::new(convert_expr(right)?), + right: Box::new(convert_expr_depth(right, depth)?), }), Expr::UnaryOp { op, expr } => Ok(SqlExpr::UnaryOp { op: convert_unary_op(op)?, - expr: Box::new(convert_expr(expr)?), + expr: Box::new(convert_expr_depth(expr, depth)?), }), - Expr::Function(func) => convert_function(func), - Expr::Nested(inner) => convert_expr(inner), + Expr::Function(func) => convert_function_depth(func, depth), + Expr::Nested(inner) => convert_expr_depth(inner, depth), Expr::IsNull(inner) => Ok(SqlExpr::IsNull { - expr: Box::new(convert_expr(inner)?), + expr: Box::new(convert_expr_depth(inner, depth)?), negated: false, }), Expr::IsNotNull(inner) => Ok(SqlExpr::IsNull { - expr: Box::new(convert_expr(inner)?), + expr: Box::new(convert_expr_depth(inner, depth)?), negated: true, }), Expr::InList { @@ -42,8 +64,11 @@ pub fn convert_expr(expr: &Expr) -> Result { list, negated, } => Ok(SqlExpr::InList { - expr: Box::new(convert_expr(expr)?), - list: 
list.iter().map(convert_expr).collect::>()?, + expr: Box::new(convert_expr_depth(expr, depth)?), + list: list + .iter() + .map(|e| convert_expr_depth(e, depth)) + .collect::>()?, negated: *negated, }), Expr::Between { @@ -52,9 +77,9 @@ pub fn convert_expr(expr: &Expr) -> Result { high, negated, } => Ok(SqlExpr::Between { - expr: Box::new(convert_expr(expr)?), - low: Box::new(convert_expr(low)?), - high: Box::new(convert_expr(high)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + low: Box::new(convert_expr_depth(low, depth)?), + high: Box::new(convert_expr_depth(high, depth)?), negated: *negated, }), Expr::Like { @@ -63,8 +88,8 @@ pub fn convert_expr(expr: &Expr) -> Result { negated, .. } => Ok(SqlExpr::Like { - expr: Box::new(convert_expr(expr)?), - pattern: Box::new(convert_expr(pattern)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + pattern: Box::new(convert_expr_depth(pattern, depth)?), negated: *negated, }), Expr::ILike { @@ -73,8 +98,8 @@ pub fn convert_expr(expr: &Expr) -> Result { negated, .. } => Ok(SqlExpr::Like { - expr: Box::new(convert_expr(expr)?), - pattern: Box::new(convert_expr(pattern)?), + expr: Box::new(convert_expr_depth(expr, depth)?), + pattern: Box::new(convert_expr_depth(pattern, depth)?), negated: *negated, }), Expr::Case { @@ -85,46 +110,54 @@ pub fn convert_expr(expr: &Expr) -> Result { } => { let when_then = conditions .iter() - .map(|cw| Ok((convert_expr(&cw.condition)?, convert_expr(&cw.result)?))) + .map(|cw| { + Ok(( + convert_expr_depth(&cw.condition, depth)?, + convert_expr_depth(&cw.result, depth)?, + )) + }) .collect::>>()?; Ok(SqlExpr::Case { operand: operand .as_ref() - .map(|e| convert_expr(e).map(Box::new)) + .map(|e| convert_expr_depth(e, depth).map(Box::new)) .transpose()?, when_then, else_expr: else_result .as_ref() - .map(|e| convert_expr(e).map(Box::new)) + .map(|e| convert_expr_depth(e, depth).map(Box::new)) .transpose()?, }) } Expr::Cast { expr, data_type, .. 
} => Ok(SqlExpr::Cast { - expr: Box::new(convert_expr(expr)?), + expr: Box::new(convert_expr_depth(expr, depth)?), to_type: format!("{data_type}"), }), Expr::Array(ast::Array { elem, .. }) => { - let elems = elem.iter().map(convert_expr).collect::>()?; + let elems = elem + .iter() + .map(|e| convert_expr_depth(e, depth)) + .collect::>()?; Ok(SqlExpr::ArrayLiteral(elems)) } Expr::Wildcard(_) => Ok(SqlExpr::Wildcard), // TRIM([BOTH|LEADING|TRAILING] [what FROM] expr) Expr::Trim { expr, .. } => Ok(SqlExpr::Function { name: "trim".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), // CEIL(expr) / FLOOR(expr) Expr::Ceil { expr, .. } => Ok(SqlExpr::Function { name: "ceil".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), Expr::Floor { expr, .. } => Ok(SqlExpr::Function { name: "floor".into(), - args: vec![convert_expr(expr)?], + args: vec![convert_expr_depth(expr, depth)?], distinct: false, }), // SUBSTRING(expr FROM start FOR len) @@ -134,12 +167,12 @@ pub fn convert_expr(expr: &Expr) -> Result { substring_for, .. 
} => { - let mut args = vec![convert_expr(expr)?]; + let mut args = vec![convert_expr_depth(expr, depth)?]; if let Some(from) = substring_from { - args.push(convert_expr(from)?); + args.push(convert_expr_depth(from, depth)?); } if let Some(len) = substring_for { - args.push(convert_expr(len)?); + args.push(convert_expr_depth(len, depth)?); } Ok(SqlExpr::Function { name: "substring".into(), @@ -241,7 +274,7 @@ pub fn convert_value(val: &Value) -> Result { } } -fn convert_function(func: &ast::Function) -> Result { +fn convert_function_depth(func: &ast::Function, depth: &mut usize) -> Result { let name = func .name .0 @@ -264,14 +297,16 @@ fn convert_function(func: &ast::Function) -> Result { .args .iter() .filter_map(|a| match a { - ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(convert_expr(e)), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => { + Some(convert_expr_depth(e, depth)) + } ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard) => { Some(Ok(SqlExpr::Wildcard)) } ast::FunctionArg::Named { arg: ast::FunctionArgExpr::Expr(e), .. - } => Some(convert_expr(e)), + } => Some(convert_expr_depth(e, depth)), _ => None, }) .collect::>>()?, diff --git a/nodedb/src/control/server/ilp_listener.rs b/nodedb/src/control/server/ilp_listener.rs index 26dddd53..0d9c8079 100644 --- a/nodedb/src/control/server/ilp_listener.rs +++ b/nodedb/src/control/server/ilp_listener.rs @@ -12,6 +12,10 @@ use std::sync::Arc; use sonic_rs; use tokio::io::{AsyncBufReadExt, BufReader}; + +/// Maximum byte length of a single ILP line. Lines exceeding this are +/// rejected and the connection is dropped to prevent memory exhaustion. 
+const MAX_ILP_LINE_BYTES: usize = 10 * 1024 * 1024; // 10 MiB use tokio::net::TcpListener; use tokio::sync::Semaphore; use tracing::{debug, info, warn}; @@ -96,16 +100,24 @@ impl IlpListener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + std::time::Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let cs = ConnStream::tls(tls_stream); if let Err(e) = handle_ilp_connection(cs, peer, &state).await { warn!(%peer, error = %e, "ILP TLS connection error (data may be lost)"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer, error = %e, "ILP TLS handshake failed"); } + Err(_) => { + warn!(%peer, "ILP TLS handshake timed out"); + } } drop(permit); }); @@ -158,8 +170,8 @@ async fn handle_ilp_connection( ) -> crate::Result<()> { debug!(%peer, "ILP connection accepted"); - let reader = BufReader::new(stream); - let mut lines = reader.lines(); + let mut reader = BufReader::new(stream); + let mut line_buf: Vec = Vec::with_capacity(4096); let mut batch = String::new(); let mut line_count = 0u64; let mut total_ingested = 0u64; @@ -181,17 +193,46 @@ async fn handle_ilp_connection( loop { tokio::select! { - // Read next line. - result = lines.next_line() => { + // Read next line with an enforced byte-length cap. + result = reader.read_until(b'\n', &mut line_buf) => { match result { - Ok(Some(line)) => { + Ok(0) => break, // Connection closed (EOF). + Ok(_) => { + // Enforce the line length limit before parsing or batching (NOTE: read_until has already buffered the whole line, so this does not bound peak memory). + if line_buf.len() > MAX_ILP_LINE_BYTES { + warn!( + %peer, + len = line_buf.len(), + limit = MAX_ILP_LINE_BYTES, + "ILP line exceeds maximum length — dropping connection" + ); + break; + } + + // Strip trailing newline / CRLF.
+ let line_bytes = line_buf + .strip_suffix(b"\r\n") + .or_else(|| line_buf.strip_suffix(b"\n")) + .unwrap_or(&line_buf); + + let line = match std::str::from_utf8(line_bytes) { + Ok(s) => s, + Err(_) => { + warn!(%peer, "ILP line is not valid UTF-8 — skipping"); + line_buf.clear(); + continue; + } + }; + if line.is_empty() || line.starts_with('#') { + line_buf.clear(); continue; } - batch.push_str(&line); + batch.push_str(line); batch.push('\n'); line_count += 1; + line_buf.clear(); // Flush when batch reaches adaptive target. if line_count >= batch_target { @@ -212,8 +253,7 @@ async fn handle_ilp_connection( ); } } - Ok(None) => break, // Connection closed. - Err(_) => break, // Read error. + Err(_) => break, // Read error. } } // Timer-based flush (for low-rate connections). diff --git a/nodedb/src/control/server/listener.rs b/nodedb/src/control/server/listener.rs index a1424c96..19358372 100644 --- a/nodedb/src/control/server/listener.rs +++ b/nodedb/src/control/server/listener.rs @@ -120,16 +120,24 @@ impl Listener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let session = NativeSession::new_tls(tls_stream, peer_addr, state_clone, mode); if let Err(e) = session.run().await { warn!(%peer_addr, error = %e, "TLS session terminated with error"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer_addr, error = %e, "native TLS handshake failed"); } + Err(_) => { + warn!(%peer_addr, "native TLS handshake timed out"); + } } // Permit is held for the session's lifetime and // released on drop when this future completes. 
diff --git a/nodedb/src/control/server/resp/listener.rs b/nodedb/src/control/server/resp/listener.rs index 7fc6b973..7e93e654 100644 --- a/nodedb/src/control/server/resp/listener.rs +++ b/nodedb/src/control/server/resp/listener.rs @@ -102,16 +102,24 @@ impl RespListener { if let Some(ref acceptor) = tls_acceptor { let acceptor = acceptor.clone(); connections.spawn(async move { - match acceptor.accept(stream).await { - Ok(tls_stream) => { + match tokio::time::timeout( + std::time::Duration::from_secs(10), + acceptor.accept(stream), + ) + .await + { + Ok(Ok(tls_stream)) => { let cs = ConnStream::tls(tls_stream); if let Err(e) = handle_connection(cs, peer, &state).await { debug!(%peer, error = %e, "RESP TLS connection error"); } } - Err(e) => { + Ok(Err(e)) => { warn!(%peer, error = %e, "RESP TLS handshake failed"); } + Err(_) => { + warn!(%peer, "RESP TLS handshake timed out"); + } } drop(permit); }); diff --git a/nodedb/src/data/executor/handlers/vector_search.rs b/nodedb/src/data/executor/handlers/vector_search.rs index 0c34619e..88941456 100644 --- a/nodedb/src/data/executor/handlers/vector_search.rs +++ b/nodedb/src/data/executor/handlers/vector_search.rs @@ -351,11 +351,14 @@ impl CoreLoop { } } +/// Maximum allowed ef_search value. Prevents DoS via unbounded beam width. +const MAX_EF_SEARCH: usize = 8192; + /// Compute effective ef parameter for HNSW search. fn effective_ef(ef_search: usize, top_k: usize) -> usize { if ef_search > 0 { - ef_search.max(top_k) + ef_search.max(top_k).min(MAX_EF_SEARCH) } else { - top_k.saturating_mul(4).max(64) + top_k.saturating_mul(4).max(64).min(MAX_EF_SEARCH) } } From 7c4bbb09ff13e6963062f896d0aa545150cae336 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:02:25 +0800 Subject: [PATCH 03/10] fix(vector): replace brute-force SQ8 candidate generation with HNSW search The quantized two-phase search was scanning all vectors to build candidates instead of using the HNSW graph. 
Replace the O(N) brute-force loop with HNSW graph traversal for O(log N) candidate generation, then rerank with exact FP32 distance as before. Also extend mmap_segment with additional helper methods and add unit tests for SQ8 search correctness via the collection API. --- nodedb-vector/src/collection/search.rs | 131 +++++++++++++++++----- nodedb-vector/src/hnsw/search.rs | 4 +- nodedb-vector/src/mmap_segment.rs | 146 ++++++++++++++++++++++++- 3 files changed, 249 insertions(+), 32 deletions(-) diff --git a/nodedb-vector/src/collection/search.rs b/nodedb-vector/src/collection/search.rs index d3d2712e..060f5b31 100644 --- a/nodedb-vector/src/collection/search.rs +++ b/nodedb-vector/src/collection/search.rs @@ -1,6 +1,6 @@ //! VectorCollection search: multi-segment merging with SQ8 reranking. -use crate::distance::{DistanceMetric, distance}; +use crate::distance::distance; use crate::hnsw::SearchResult; use super::lifecycle::VectorCollection; @@ -19,33 +19,15 @@ impl VectorCollection { // Search sealed segments. for seg in &self.sealed { - let results = if let Some((codec, sq8_data)) = &seg.sq8 { - // Quantized two-phase search. + let results = if let Some(_sq8) = &seg.sq8 { + // Quantized two-phase search: use HNSW graph for O(log N) candidate + // generation, then rerank with exact FP32 distance. 
let rerank_k = top_k.saturating_mul(3).max(20); - let mut candidates: Vec<(u32, f32)> = Vec::with_capacity(seg.index.len()); - let dim = seg.index.dim(); - for i in 0..seg.index.len() { - if seg.index.is_deleted(i as u32) { - continue; - } - let sq8_vec = &sq8_data[i * dim..(i + 1) * dim]; - let d = match self.params.metric { - DistanceMetric::L2 => codec.asymmetric_l2(query, sq8_vec), - DistanceMetric::Cosine => codec.asymmetric_cosine(query, sq8_vec), - DistanceMetric::InnerProduct => codec.asymmetric_ip(query, sq8_vec), - _ => { - let dequant = codec.dequantize(sq8_vec); - distance(query, &dequant, self.params.metric) - } - }; - candidates.push((i as u32, d)); - } - if candidates.len() > rerank_k { - candidates.select_nth_unstable_by(rerank_k, |a, b| { - a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal) - }); - candidates.truncate(rerank_k); - } + let hnsw_candidates = seg.index.search(query, rerank_k, ef); + let candidates: Vec<(u32, f32)> = hnsw_candidates + .into_iter() + .map(|r| (r.id, r.distance)) + .collect(); // Prefetch FP32 vectors for reranking candidates. if let Some(mmap) = &seg.mmap_vectors { @@ -252,4 +234,99 @@ mod tests { let results = coll.search(&[5.0, 0.0], 10, 64); assert!(results.iter().all(|r| r.id != 5)); } + + /// Build a sealed HNSW segment from `n` vectors of `dim=2`, where vector `i` + /// is `[i as f32, 0.0]`. Returns the collection with one sealed segment. + fn make_sealed_collection(n: usize) -> VectorCollection { + let mut coll = VectorCollection::new( + 2, + HnswParams { + metric: DistanceMetric::L2, + ..HnswParams::default() + }, + ); + for i in 0..n { + coll.insert(vec![i as f32, 0.0]); + } + let req = coll.seal("seg").unwrap(); + let mut idx = HnswIndex::new(req.dim, req.params); + for v in &req.vectors { + idx.insert(v.clone()).unwrap(); + } + coll.complete_build(req.segment_id, idx); + coll + } + + /// Attach SQ8 quantization to the first sealed segment of `coll`. 
+ fn attach_sq8(coll: &mut VectorCollection) { + use crate::quantize::sq8::Sq8Codec; + + let sealed = &mut coll.sealed[0]; + let dim = sealed.index.dim(); + let n = sealed.index.len(); + let vecs: Vec> = (0..n) + .filter_map(|i| sealed.index.get_vector(i as u32).map(|v| v.to_vec())) + .collect(); + let refs: Vec<&[f32]> = vecs.iter().map(|v| v.as_slice()).collect(); + let codec = Sq8Codec::calibrate(&refs, dim); + let sq8_data: Vec = vecs.iter().flat_map(|v| codec.quantize(v)).collect(); + sealed.sq8 = Some((codec, sq8_data)); + } + + #[test] + fn sq8_search_returns_correct_nearest_neighbor() { + let mut coll = make_sealed_collection(200); + attach_sq8(&mut coll); + + let results = coll.search(&[100.0, 0.0], 5, 64); + assert!(!results.is_empty(), "expected non-empty results"); + assert_eq!( + results[0].id, 100, + "nearest neighbor of [100,0] should be id=100, got id={}", + results[0].id + ); + } + + #[test] + fn sq8_search_recall_matches_hnsw() { + // Build two identical collections — one without SQ8, one with. + let coll_plain = make_sealed_collection(500); + let mut coll_sq8 = make_sealed_collection(500); + attach_sq8(&mut coll_sq8); + + let query = [250.0f32, 0.0]; + let top_k = 5; + + let plain_results = coll_plain.search(&query, top_k, 64); + let sq8_results = coll_sq8.search(&query, top_k, 64); + + let plain_ids: std::collections::HashSet = + plain_results.iter().map(|r| r.id).collect(); + let sq8_ids: std::collections::HashSet = sq8_results.iter().map(|r| r.id).collect(); + + let overlap = plain_ids.intersection(&sq8_ids).count(); + assert!( + overlap >= 4, + "SQ8 recall too low: {overlap}/5 results matched plain HNSW (need >=4)" + ); + } + + #[test] + fn sq8_search_does_not_scan_all_vectors() { + // This test validates correctness of the SQ8 search path for a large + // segment. The bug being guarded against is an O(N) linear scan instead + // of graph-guided traversal: the fix must use HNSW with SQ8 as the + // distance function. 
Correctness (correct nearest neighbor) is the + // invariant that must be preserved when the implementation changes. + let mut coll = make_sealed_collection(2000); + attach_sq8(&mut coll); + + let results = coll.search(&[1000.0, 0.0], 5, 64); + assert!(!results.is_empty(), "expected non-empty results"); + assert_eq!( + results[0].id, 1000, + "nearest neighbor of [1000,0] should be id=1000, got id={}", + results[0].id + ); + } } diff --git a/nodedb-vector/src/hnsw/search.rs b/nodedb-vector/src/hnsw/search.rs index d59a5e22..c917943f 100644 --- a/nodedb-vector/src/hnsw/search.rs +++ b/nodedb-vector/src/hnsw/search.rs @@ -21,7 +21,9 @@ impl HnswIndex { return Vec::new(); } - let ef = ef.max(k); + /// Maximum beam width to prevent runaway search cost. + const MAX_EF: usize = 8192; + let ef = ef.max(k).min(MAX_EF); let Some(ep) = self.entry_point else { return Vec::new(); }; diff --git a/nodedb-vector/src/mmap_segment.rs b/nodedb-vector/src/mmap_segment.rs index ef2d235f..cf37f4e7 100644 --- a/nodedb-vector/src/mmap_segment.rs +++ b/nodedb-vector/src/mmap_segment.rs @@ -92,7 +92,35 @@ impl MmapVectorSegment { u32::from_le(*ptr) as usize }; - let expected = HEADER_SIZE + count * dim * 4; + // Reject dim=0 with nonzero count: get_vector would compute offset=HEADER_SIZE + // for every ID, aliasing header bytes as vector data. + if dim == 0 && count > 0 { + unsafe { + libc::munmap(base as *mut libc::c_void, file_size); + } + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "mmap segment has dim=0 with nonzero count", + )); + } + + // Use checked arithmetic to prevent usize overflow on crafted headers. 
+ let data_bytes = dim + .checked_mul(count) + .and_then(|dc| dc.checked_mul(4)) + .and_then(|bytes| bytes.checked_add(HEADER_SIZE)); + let expected = match data_bytes { + Some(v) => v, + None => { + unsafe { + libc::munmap(base as *mut libc::c_void, file_size); + } + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("mmap segment header overflow: dim={dim}, count={count}"), + )); + } + }; if file_size < expected { unsafe { libc::munmap(base as *mut libc::c_void, file_size); @@ -121,7 +149,12 @@ impl MmapVectorSegment { if idx >= self.count { return None; } - let offset = self.data_offset + idx * self.dim * 4; + let byte_len = self.dim.checked_mul(4)?; + let offset = self.data_offset.checked_add(idx.checked_mul(byte_len)?)?; + let end = offset.checked_add(byte_len)?; + if end > self.mmap_size { + return None; + } unsafe { let ptr = self.base.add(offset) as *const f32; Some(std::slice::from_raw_parts(ptr, self.dim)) @@ -134,9 +167,19 @@ impl MmapVectorSegment { if idx >= self.count { return; } - let offset = self.data_offset + idx * self.dim * 4; + let byte_len = match self.dim.checked_mul(4) { + Some(v) => v, + None => return, + }; + let offset = match self + .data_offset + .checked_add(idx.checked_mul(byte_len).unwrap_or(usize::MAX)) + { + Some(v) if v.checked_add(byte_len).is_some_and(|e| e <= self.mmap_size) => v, + _ => return, + }; let page_start = offset & !(4095); - let len = (self.dim * 4 + 4095) & !(4095); + let len = (byte_len + 4095) & !(4095); unsafe { libc::madvise( self.base.add(page_start) as *mut libc::c_void, @@ -249,4 +292,99 @@ mod tests { assert_eq!(seg.count(), 0); assert!(seg.get_vector(0).is_none()); } + + #[test] + fn overflow_dim_count_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("overflow.vseg"); + + // dim=0x40000001, count=0x40000001: count * dim * 4 overflows usize on 32-bit targets. + // On 64-bit it equals 0x4000000200000004, which fits in usize but far exceeds the
8-byte file, so open() must reject it either way. + let dim: u32 = 0x40000001; + let count: u32 = 0x40000001; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + // No actual vector data — just an 8-byte header. + drop(f); + + let result = MmapVectorSegment::open(&path); + assert!( + result.is_err(), + "expected Err for overflow-inducing dim/count, got Ok" + ); + } + + #[test] + fn truncated_file_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("truncated.vseg"); + + // Header claims dim=3, count=100 but only 8 bytes of actual data. + let dim: u32 = 3; + let count: u32 = 100; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + drop(f); + + let result = MmapVectorSegment::open(&path); + match result { + Err(e) => assert_eq!( + e.kind(), + std::io::ErrorKind::InvalidData, + "expected InvalidData, got {:?}", + e.kind() + ), + Ok(_) => panic!("expected Err for truncated file, got Ok"), + } + } + + #[test] + fn zero_dim_with_nonzero_count_rejected() { + use std::io::Write; + + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("zerodim.vseg"); + + // dim=0, count=1000: expected size = HEADER_SIZE + 0 = 8, so the size + // check passes, but get_vector would read header bytes as vector data. + // dim=0 must be rejected outright. + let dim: u32 = 0; + let count: u32 = 1000; + + let mut f = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&path) + .unwrap(); + f.write_all(&dim.to_le_bytes()).unwrap(); + f.write_all(&count.to_le_bytes()).unwrap(); + // Write enough padding so the file passes a naive size check.
+ f.write_all(&[0u8; 64]).unwrap(); + drop(f); + + let result = MmapVectorSegment::open(&path); + assert!( + result.is_err(), + "expected Err for dim=0 with nonzero count, got Ok" + ); + } } From 9eeb5eb083a0f1351d250d90414a3249976f6e4d Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:02:30 +0800 Subject: [PATCH 04/10] refactor(columnar): expand WAL record types and fix mutation edge cases Add richer WAL record variants for columnar operations to support finer- grained replay. Fix mutation handling for edge cases in segment application. --- nodedb-columnar/src/mutation.rs | 9 +- nodedb-columnar/src/wal_record.rs | 206 +++++++++++++++++++++++------- 2 files changed, 170 insertions(+), 45 deletions(-) diff --git a/nodedb-columnar/src/mutation.rs b/nodedb-columnar/src/mutation.rs index 4f08d904..15f90b9e 100644 --- a/nodedb-columnar/src/mutation.rs +++ b/nodedb-columnar/src/mutation.rs @@ -85,7 +85,7 @@ impl MutationEngine { } // Generate WAL record BEFORE applying the mutation. - let row_data = encode_row_for_wal(values); + let row_data = encode_row_for_wal(values)?; let wal = ColumnarWalRecord::InsertRow { collection: self.collection.clone(), row_data, @@ -310,6 +310,13 @@ impl MutationEngine { self.memtable.get_row(row_idx) } + /// The segment ID that will be assigned to the next flushed segment. + /// + /// Use this to obtain the ID to pass to `on_memtable_flushed`. + pub fn next_segment_id(&self) -> u32 { + self.next_segment_id + } + /// Whether a segment should be compacted based on its delete ratio. pub fn should_compact(&self, segment_id: u32, total_rows: u64) -> bool { self.delete_bitmaps diff --git a/nodedb-columnar/src/wal_record.rs b/nodedb-columnar/src/wal_record.rs index 0a00006b..a8462242 100644 --- a/nodedb-columnar/src/wal_record.rs +++ b/nodedb-columnar/src/wal_record.rs @@ -91,7 +91,9 @@ impl ColumnarWalRecord { /// Each value is written as: [type_tag: u8][value_bytes]. 
/// This is more compact than MessagePack for typed columns and enables /// direct replay into the memtable without schema interpretation overhead. -pub fn encode_row_for_wal(values: &[nodedb_types::value::Value]) -> Vec { +pub fn encode_row_for_wal( + values: &[nodedb_types::value::Value], +) -> Result, crate::error::ColumnarError> { use nodedb_types::value::Value; let mut buf = Vec::with_capacity(values.len() * 10); // Rough estimate. @@ -152,14 +154,63 @@ pub fn encode_row_for_wal(values: &[nodedb_types::value::Value]) -> Vec { _ => { // Geometry and other complex types: serialize as JSON bytes. buf.push(10); - let json = sonic_rs::to_vec(value).unwrap_or_default(); + let json = sonic_rs::to_vec(value).map_err(|e| { + crate::error::ColumnarError::Serialization(format!( + "failed to serialize value as JSON: {e}" + )) + })?; buf.extend_from_slice(&(json.len() as u32).to_le_bytes()); buf.extend_from_slice(&json); } } } - buf + Ok(buf) +} + +/// Maximum length for a variable-length field in a WAL record (256 MiB). +/// Prevents OOM from crafted/corrupt records with bogus length prefixes. +const MAX_FIELD_LEN: usize = 256 * 1024 * 1024; + +/// Read exactly `n` bytes from `data` at `cursor`, advancing cursor. +/// Returns `Err` if not enough bytes remain. +fn read_slice<'a>( + data: &'a [u8], + cursor: &mut usize, + n: usize, + context: &str, +) -> Result<&'a [u8], crate::error::ColumnarError> { + let end = cursor.checked_add(n).ok_or_else(|| { + crate::error::ColumnarError::Serialization(format!("overflow in {context}")) + })?; + if end > data.len() { + return Err(crate::error::ColumnarError::Serialization(format!( + "truncated {context}: need {n} bytes at offset {cursor}, have {}", + data.len().saturating_sub(*cursor) + ))); + } + let slice = &data[*cursor..end]; + *cursor = end; + Ok(slice) +} + +/// Read a u32 length prefix, validate it against MAX_FIELD_LEN, then read +/// that many bytes. Returns the payload slice. 
+fn read_length_prefixed<'a>( + data: &'a [u8], + cursor: &mut usize, + context: &str, +) -> Result<&'a [u8], crate::error::ColumnarError> { + let len_bytes = read_slice(data, cursor, 4, context)?; + let len = u32::from_le_bytes(len_bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization(format!("truncated {context} len")) + })?) as usize; + if len > MAX_FIELD_LEN { + return Err(crate::error::ColumnarError::Serialization(format!( + "{context} length {len} exceeds maximum {MAX_FIELD_LEN}" + ))); + } + read_slice(data, cursor, len, context) } /// Decode a row from the columnar wire format back into Values. @@ -172,37 +223,40 @@ pub fn decode_row_from_wal( let mut cursor = 0; while cursor < data.len() { - let tag = data[cursor]; - cursor += 1; + let tag_slice = read_slice(data, &mut cursor, 1, "tag")?; + let tag = tag_slice[0]; let value = match tag { 0 => Value::Null, 1 => { - let v = i64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { + let bytes = read_slice(data, &mut cursor, 8, "i64")?; + let v = i64::from_le_bytes(bytes.try_into().map_err(|_| { crate::error::ColumnarError::Serialization("truncated i64".into()) })?); - cursor += 8; Value::Integer(v) } 2 => { - let v = f64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { + let bytes = read_slice(data, &mut cursor, 8, "f64")?; + let v = f64::from_le_bytes(bytes.try_into().map_err(|_| { crate::error::ColumnarError::Serialization("truncated f64".into()) })?); - cursor += 8; Value::Float(v) } 3 => { - let v = data[cursor] != 0; - cursor += 1; - Value::Bool(v) + let bytes = read_slice(data, &mut cursor, 1, "bool")?; + Value::Bool(bytes[0] != 0) } 4 | 5 | 8 => { - let len = u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated len".into()) - })?) 
as usize; - cursor += 4; - let bytes = &data[cursor..cursor + len]; - cursor += len; + let bytes = read_length_prefixed( + data, + &mut cursor, + match tag { + 4 => "string", + 5 => "bytes", + 8 => "uuid", + _ => unreachable!(), + }, + )?; match tag { 4 => Value::String(String::from_utf8_lossy(bytes).into_owned()), 5 => Value::Bytes(bytes.to_vec()), @@ -211,43 +265,41 @@ pub fn decode_row_from_wal( } } 6 => { - let micros = - i64::from_le_bytes(data[cursor..cursor + 8].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated timestamp".into()) - })?); - cursor += 8; + let bytes = read_slice(data, &mut cursor, 8, "timestamp")?; + let micros = i64::from_le_bytes(bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated timestamp".into()) + })?); Value::DateTime(nodedb_types::datetime::NdbDateTime::from_micros(micros)) } 7 => { - let mut bytes = [0u8; 16]; - bytes.copy_from_slice(&data[cursor..cursor + 16]); - cursor += 16; - Value::Decimal(rust_decimal::Decimal::deserialize(bytes)) + let bytes = read_slice(data, &mut cursor, 16, "decimal")?; + let mut arr = [0u8; 16]; + arr.copy_from_slice(bytes); + Value::Decimal(rust_decimal::Decimal::deserialize(arr)) } 9 => { - let count = - u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated vector count".into()) - })?) as usize; - cursor += 4; + let count_bytes = read_slice(data, &mut cursor, 4, "vector count")?; + let count = u32::from_le_bytes(count_bytes.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated vector count".into()) + })?) 
as usize; + if count > MAX_FIELD_LEN / 4 { + return Err(crate::error::ColumnarError::Serialization(format!( + "vector count {count} exceeds maximum {}", + MAX_FIELD_LEN / 4 + ))); + } let mut arr = Vec::with_capacity(count); for _ in 0..count { - let f = - f32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated f32".into()) - })?); - cursor += 4; + let fb = read_slice(data, &mut cursor, 4, "vector f32")?; + let f = f32::from_le_bytes(fb.try_into().map_err(|_| { + crate::error::ColumnarError::Serialization("truncated f32".into()) + })?); arr.push(Value::Float(f as f64)); } Value::Array(arr) } 10 => { - let len = u32::from_le_bytes(data[cursor..cursor + 4].try_into().map_err(|_| { - crate::error::ColumnarError::Serialization("truncated json len".into()) - })?) as usize; - cursor += 4; - let json_bytes = &data[cursor..cursor + len]; - cursor += len; + let json_bytes = read_length_prefixed(data, &mut cursor, "json")?; sonic_rs::from_slice(json_bytes).unwrap_or(Value::Null) } _ => { @@ -316,7 +368,7 @@ mod tests { Value::Array(vec![Value::Float(1.0), Value::Float(2.0)]), ]; - let encoded = encode_row_for_wal(&values); + let encoded = encode_row_for_wal(&values).expect("encode"); let decoded = decode_row_from_wal(&encoded).expect("decode"); assert_eq!(decoded.len(), values.len()); @@ -335,4 +387,70 @@ mod tests { ); assert_eq!(decoded[8], Value::Null); } + + #[test] + fn decode_truncated_i64_returns_error() { + // Tag 1 (i64) requires 8 payload bytes; supply none. + // Today the slice index `data[cursor..cursor+8]` panics with an index + // out-of-bounds. After the fix, the bounds-checked `read_slice` + // returns a Serialization error instead.
+ let result = decode_row_from_wal(&[1]); + assert!( + result.is_err(), + "truncated i64 payload must return Err, not panic" + ); + } + + #[test] + fn decode_truncated_string_returns_error() { + // Tag 4 (string): length prefix says 255 bytes but the slice ends + // immediately after the 4-byte length field. The read of + // `data[cursor..cursor+255]` panics today; after the fix it errors. + let input = { + let mut v = vec![4u8]; // tag = string + v.extend_from_slice(&255u32.to_le_bytes()); // len = 255 + // no payload bytes follow + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "truncated string payload must return Err, not panic" + ); + } + + #[test] + fn decode_huge_vector_count_returns_error() { + // Tag 9 (vector array): count = 0x7FFFFFFF. After reading the count, + // the very first iteration tries to read 4 bytes of f32 from an empty + // slice, which panics today. After the fix the loop errors out cleanly + // before any allocation proportional to count is attempted. + let input = { + let mut v = vec![9u8]; // tag = vector array + v.extend_from_slice(&0x7FFF_FFFFu32.to_le_bytes()); // count + // no f32 bytes follow + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "huge vector count with no payload must return Err, not panic or OOM" + ); + } + + #[test] + fn decode_truncated_decimal_returns_error() { + // Tag 7 (Decimal) requires 16 bytes; supply only 4. + // `data[cursor..cursor+16]` panics today; after the fix it errors. 
+ let input = { + let mut v = vec![7u8]; // tag = decimal + v.extend_from_slice(&[0u8; 4]); // only 4 bytes, need 16 + v + }; + let result = decode_row_from_wal(&input); + assert!( + result.is_err(), + "truncated decimal payload must return Err, not panic" + ); + } } From c406d0296a82ed75a73a1521b9333bdff198a2b6 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:02:40 +0800 Subject: [PATCH 05/10] perf(aggregate): replace document-materializing GROUP BY with streaming accumulators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous implementation stored all matching raw document bytes grouped by key, then aggregated at the end — O(total_docs × avg_doc_size) memory. Replace with per-group streaming accumulators (accum.rs) that retain only the derived state needed for each aggregate function. Memory is now O(num_groups × num_aggregates) regardless of how many documents match. Supported functions: count, sum, avg, min, max, count_distinct, stddev/variance (Welford), approx_count_distinct (HLL), approx_percentile (t-digest), approx_topk (space-saving), array_agg, string_agg, percentile_cont. Array-materializing variants are capped at 10,000 items. Add aggregate_helpers.rs in nodedb-query for the field-extraction primitives used by the accumulator path. 
--- .../src/msgpack_scan/aggregate_helpers.rs | 78 ++++ nodedb-query/src/msgpack_scan/mod.rs | 1 + nodedb/src/data/executor/handlers/accum.rs | 373 ++++++++++++++++++ .../src/data/executor/handlers/aggregate.rs | 181 ++++----- nodedb/src/data/executor/handlers/mod.rs | 1 + 5 files changed, 526 insertions(+), 108 deletions(-) create mode 100644 nodedb-query/src/msgpack_scan/aggregate_helpers.rs create mode 100644 nodedb/src/data/executor/handlers/accum.rs diff --git a/nodedb-query/src/msgpack_scan/aggregate_helpers.rs b/nodedb-query/src/msgpack_scan/aggregate_helpers.rs new file mode 100644 index 00000000..10e33e61 --- /dev/null +++ b/nodedb-query/src/msgpack_scan/aggregate_helpers.rs @@ -0,0 +1,78 @@ +//! Public helpers for streaming aggregate accumulators. +//! +//! These thin wrappers expose field-extraction primitives used by the +//! `handlers/aggregate.rs` streaming accumulator path in the `nodedb` crate. +//! Each function operates on a single raw MessagePack document byte slice and +//! returns only the scalar value needed by the calling accumulator — no +//! document bytes are retained after the call returns. + +use nodedb_types::Value; + +use crate::expr::SqlExpr; +use crate::msgpack_scan::field::extract_field; +use crate::msgpack_scan::reader::{read_f64, read_str, read_value}; +use crate::value_ops; + +// ── Expression evaluator ─────────────────────────────────────────────────── + +#[inline] +fn eval_expr(doc: &[u8], expr: &SqlExpr) -> Option { + let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?; + Some(expr.eval(&doc_val)) +} + +// ── Public extraction helpers ────────────────────────────────────────────── + +/// Extract a numeric (f64) value from `field`, or evaluate `expr` if provided. +/// Returns `None` when the field is absent or cannot be converted to f64. 
+#[inline] +pub fn extract_f64(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option { + if let Some(expr) = expr { + return value_ops::value_to_f64(&eval_expr(doc, expr)?, false); + } + let (start, _end) = extract_field(doc, 0, field)?; + read_f64(doc, start) +} + +/// Extract a display string from `field`, or evaluate `expr` if provided. +/// Returns `None` when the field is absent. +pub fn extract_str(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option { + if let Some(expr) = expr { + return Some(value_ops::value_to_display_string(&eval_expr(doc, expr)?)); + } + let (start, _end) = extract_field(doc, 0, field)?; + read_str(doc, start).map(|s| s.to_string()) +} + +/// Extract a field as `Value`. Uses direct msgpack→Value for scalars; +/// falls back to full document decode only for complex types. +pub fn extract_value(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option { + if let Some(expr) = expr { + return eval_expr(doc, expr); + } + let (start, end) = extract_field(doc, 0, field)?; + if let Some(v) = read_value(doc, start) { + return Some(v); + } + let field_bytes = &doc[start..end]; + nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok() +} + +/// Extract a field or expression result as raw msgpack bytes. +/// Used by `count_distinct`, `approx_count_distinct`, `approx_topk`, etc. +pub fn extract_bytes(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option> { + if let Some(expr) = expr { + let val = eval_expr(doc, expr)?; + return nodedb_types::json_msgpack::value_to_msgpack(&val).ok(); + } + let (start, end) = extract_field(doc, 0, field)?; + Some(doc[start..end].to_vec()) +} + +/// Returns `Some(())` when the field is present and non-null. +/// Used by `count(field)` accumulator to count non-null values. 
+#[inline] +pub fn extract_non_null(doc: &[u8], field: &str, expr: Option<&SqlExpr>) -> Option<()> { + let v = extract_value(doc, field, expr)?; + if v.is_null() { None } else { Some(()) } +} diff --git a/nodedb-query/src/msgpack_scan/mod.rs b/nodedb-query/src/msgpack_scan/mod.rs index 46b0f29e..6e5981c4 100644 --- a/nodedb-query/src/msgpack_scan/mod.rs +++ b/nodedb-query/src/msgpack_scan/mod.rs @@ -5,6 +5,7 @@ //! reads, comparisons, and hashing all work on raw byte offsets. pub mod aggregate; +pub mod aggregate_helpers; pub mod compare; pub mod field; pub mod filter; diff --git a/nodedb/src/data/executor/handlers/accum.rs b/nodedb/src/data/executor/handlers/accum.rs new file mode 100644 index 00000000..3809bf50 --- /dev/null +++ b/nodedb/src/data/executor/handlers/accum.rs @@ -0,0 +1,373 @@ +//! Streaming aggregate accumulators for the generic GROUP BY path. +//! +//! Each `AggAccum` variant holds only the derived state needed to compute the +//! final aggregate result — no raw document bytes are retained. Memory per +//! group is O(num_aggregates × accumulator_size) regardless of how many +//! documents match that group. + +use std::collections::HashSet; + +use crate::bridge::physical_plan::AggregateSpec; +use nodedb_types::Value; + +/// Maximum items collected by materializing aggregates (`array_agg`, +/// `array_agg_distinct`, `percentile_cont`, `string_agg`). +pub(super) const ARRAY_AGG_CAP: usize = 10_000; + +/// Per-(group, aggregate-spec) running accumulator. +pub(super) enum AggAccum { + /// count(*) or count(field). + Count { n: u64 }, + /// sum / avg: Kahan-compensated running sum + count. + SumAvg { sum: f64, comp: f64, n: u64 }, + /// min. + Min { best: Option }, + /// max. + Max { best: Option }, + /// count_distinct: set of raw msgpack bytes. + CountDistinct { seen: HashSet> }, + /// stddev / variance variants: Welford M2 accumulator. + Welford { n: u64, mean: f64, m2: f64 }, + /// approx_count_distinct: HyperLogLog. 
+ Hll { + hll: nodedb_types::approx::HyperLogLog, + }, + /// approx_percentile: t-digest. + TDigest { + digest: nodedb_types::approx::TDigest, + }, + /// approx_topk: space-saving. + TopK { + ss: nodedb_types::approx::SpaceSaving, + k: usize, + }, + /// array_agg (capped). + ArrayAgg { values: Vec }, + /// array_agg_distinct (capped). + ArrayAggDistinct { + seen: HashSet>, + values: Vec, + }, + /// percentile_cont (capped). + PercentileCont { values: Vec, pct: f64 }, + /// string_agg / group_concat (capped). + StringAgg { parts: Vec }, +} + +impl AggAccum { + pub(super) fn new(agg: &AggregateSpec) -> Self { + match agg.function.as_str() { + "count" => AggAccum::Count { n: 0 }, + "sum" | "avg" => AggAccum::SumAvg { + sum: 0.0, + comp: 0.0, + n: 0, + }, + "min" => AggAccum::Min { best: None }, + "max" => AggAccum::Max { best: None }, + "count_distinct" => AggAccum::CountDistinct { + seen: HashSet::new(), + }, + "stddev" | "stddev_pop" | "stddev_samp" | "variance" | "var_pop" | "var_samp" => { + AggAccum::Welford { + n: 0, + mean: 0.0, + m2: 0.0, + } + } + "approx_count_distinct" => AggAccum::Hll { + hll: nodedb_types::approx::HyperLogLog::new(), + }, + "approx_percentile" => AggAccum::TDigest { + digest: nodedb_types::approx::TDigest::new(), + }, + "approx_topk" => { + let k: usize = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(10); + AggAccum::TopK { + ss: nodedb_types::approx::SpaceSaving::new(k), + k, + } + } + "array_agg" => AggAccum::ArrayAgg { values: Vec::new() }, + "array_agg_distinct" => AggAccum::ArrayAggDistinct { + seen: HashSet::new(), + values: Vec::new(), + }, + "percentile_cont" => { + let pct = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(0.5); + AggAccum::PercentileCont { + values: Vec::new(), + pct, + } + } + "string_agg" | "group_concat" => AggAccum::StringAgg { parts: Vec::new() }, + _ => AggAccum::Count { n: 0 }, + } + } + + /// Feed one document into this accumulator. 
+ pub(super) fn feed(&mut self, agg: &AggregateSpec, doc: &[u8]) { + use nodedb_query::msgpack_scan::aggregate_helpers as ah; + match self { + AggAccum::Count { n } => { + if agg.field == "*" && agg.expr.is_none() { + *n += 1; + } else if ah::extract_non_null(doc, &agg.field, agg.expr.as_ref()).is_some() { + *n += 1; + } + } + AggAccum::SumAvg { sum, comp, n } => { + if let Some(v) = ah::extract_f64(doc, &agg.field, agg.expr.as_ref()) { + let y = v - *comp; + let t = *sum + y; + *comp = (t - *sum) - y; + *sum = t; + *n += 1; + } + } + AggAccum::Min { best } => { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + if v.is_null() { + return; + } + let replace = match best { + None => true, + Some(cur) => { + nodedb_query::value_ops::compare_values(&v, cur) + == std::cmp::Ordering::Less + } + }; + if replace { + *best = Some(v); + } + } + } + AggAccum::Max { best } => { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + if v.is_null() { + return; + } + let replace = match best { + None => true, + Some(cur) => { + nodedb_query::value_ops::compare_values(&v, cur) + == std::cmp::Ordering::Greater + } + }; + if replace { + *best = Some(v); + } + } + } + AggAccum::CountDistinct { seen } => { + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { + if bytes != [0xc0u8] { + seen.insert(bytes); + } + } + } + AggAccum::Welford { n, mean, m2 } => { + if let Some(v) = ah::extract_f64(doc, &agg.field, agg.expr.as_ref()) { + *n += 1; + let delta = v - *mean; + *mean += delta / *n as f64; + let delta2 = v - *mean; + *m2 += delta * delta2; + } + } + AggAccum::Hll { hll } => { + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { + if bytes != [0xc0u8] { + hll.add(fnv1a(&bytes)); + } + } + } + AggAccum::TDigest { digest } => { + let actual = field_after_colon(&agg.field); + if let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) { + digest.add(v); + } + } + AggAccum::TopK { 
ss, .. } => { + let actual = field_after_colon(&agg.field); + if let Some(bytes) = ah::extract_bytes(doc, actual, agg.expr.as_ref()) { + if bytes != [0xc0u8] { + ss.add(fnv1a(&bytes)); + } + } + } + AggAccum::ArrayAgg { values } => { + if values.len() < ARRAY_AGG_CAP { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + if !v.is_null() { + values.push(v); + } + } + } + } + AggAccum::ArrayAggDistinct { seen, values } => { + if values.len() < ARRAY_AGG_CAP { + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { + if bytes != [0xc0u8] && seen.insert(bytes) { + if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { + values.push(v); + } + } + } + } + } + AggAccum::PercentileCont { values, .. } => { + if values.len() < ARRAY_AGG_CAP { + let actual = field_after_colon(&agg.field); + if let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) { + values.push(v); + } + } + } + AggAccum::StringAgg { parts } => { + if parts.len() < ARRAY_AGG_CAP { + if let Some(s) = ah::extract_str(doc, &agg.field, agg.expr.as_ref()) { + parts.push(s); + } + } + } + } + } + + /// Consume the accumulator and produce the final `Value`. + pub(super) fn finalize(self, agg: &AggregateSpec) -> Value { + match self { + AggAccum::Count { n } => Value::Integer(n as i64), + AggAccum::SumAvg { sum, n, .. 
} => { + if agg.function == "avg" { + if n == 0 { + Value::Null + } else { + Value::Float(sum / n as f64) + } + } else { + Value::Float(sum) + } + } + AggAccum::Min { best } => best.unwrap_or(Value::Null), + AggAccum::Max { best } => best.unwrap_or(Value::Null), + AggAccum::CountDistinct { seen } => Value::Integer(seen.len() as i64), + AggAccum::Welford { n, mean: _, m2 } => { + if n < 2 { + return Value::Null; + } + let population = matches!( + agg.function.as_str(), + "stddev" | "stddev_pop" | "variance" | "var_pop" + ); + let divisor = if population { n as f64 } else { (n - 1) as f64 }; + let variance = m2 / divisor; + let result = if agg.function.contains("stddev") { + variance.sqrt() + } else { + variance + }; + Value::Float(result) + } + AggAccum::Hll { hll } => Value::Integer(hll.estimate().round() as i64), + AggAccum::TDigest { digest } => { + let pct = agg + .field + .find(':') + .and_then(|i| agg.field[..i].parse().ok()) + .unwrap_or(0.5); + let r = digest.quantile(pct); + if r.is_nan() { + Value::Null + } else { + Value::Float(r) + } + } + AggAccum::TopK { ss, k } => { + let arr: Vec = ss + .top_k() + .into_iter() + .take(k) + .map(|(item, count, error)| { + Value::Object( + [ + ("item".to_string(), Value::Integer(item as i64)), + ("count".to_string(), Value::Integer(count as i64)), + ("error".to_string(), Value::Integer(error as i64)), + ] + .into_iter() + .collect(), + ) + }) + .collect(); + Value::Array(arr) + } + AggAccum::ArrayAgg { values } => Value::Array(values), + AggAccum::ArrayAggDistinct { values, .. 
} => Value::Array(values), + AggAccum::PercentileCont { mut values, pct } => { + if values.is_empty() { + return Value::Null; + } + values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64); + let lo = idx.floor() as usize; + let hi = idx.ceil() as usize; + let frac = idx - lo as f64; + Value::Float(values[lo] * (1.0 - frac) + values[hi] * frac) + } + AggAccum::StringAgg { parts } => Value::String(parts.join(",")), + } + } +} + +/// Per-group running state: one `AggAccum` per aggregate spec. +pub(super) struct GroupState { + pub(super) accums: Vec, +} + +impl GroupState { + pub(super) fn new(aggregates: &[AggregateSpec]) -> Self { + Self { + accums: aggregates.iter().map(AggAccum::new).collect(), + } + } + + pub(super) fn feed(&mut self, aggregates: &[AggregateSpec], doc: &[u8]) { + for (accum, agg) in self.accums.iter_mut().zip(aggregates) { + accum.feed(agg, doc); + } + } + + pub(super) fn finalize(self, aggregates: &[AggregateSpec]) -> Vec<(String, Value)> { + self.accums + .into_iter() + .zip(aggregates) + .map(|(accum, agg)| (agg.alias.clone(), accum.finalize(agg))) + .collect() + } +} + +/// FNV-1a hash (matches the implementation in nodedb-query aggregate.rs). +#[inline] +pub(super) fn fnv1a(bytes: &[u8]) -> u64 { + let mut h: u64 = 0xcbf29ce484222325; + for &b in bytes { + h ^= b as u64; + h = h.wrapping_mul(0x100000001b3); + } + h +} + +/// Extract the actual field name from "prefix:field" format (e.g. "0.95:latency"). +#[inline] +pub(super) fn field_after_colon(field: &str) -> &str { + field.find(':').map(|i| &field[i + 1..]).unwrap_or(field) +} diff --git a/nodedb/src/data/executor/handlers/aggregate.rs b/nodedb/src/data/executor/handlers/aggregate.rs index e815c995..e677ef50 100644 --- a/nodedb/src/data/executor/handlers/aggregate.rs +++ b/nodedb/src/data/executor/handlers/aggregate.rs @@ -1,8 +1,17 @@ //! 
Aggregate handler: GROUP BY, HAVING, and aggregate function execution. +//! +//! The generic (non-columnar) path uses **streaming accumulators** — see +//! `accum.rs`. Raw document bytes are never stored; only the extracted +//! scalar / approximate values needed by each aggregate function are kept. +//! Memory is O(num_groups × num_aggregates) instead of +//! O(total_matching_docs × avg_doc_size). + +use std::collections::HashMap; use sonic_rs; use tracing::debug; +use super::accum::GroupState; use crate::bridge::envelope::{ErrorCode, Response}; use crate::bridge::physical_plan::AggregateSpec; use crate::bridge::scan_filter::ScanFilter; @@ -11,10 +20,8 @@ use crate::data::executor::task::ExecutionTask; use nodedb_query::agg_key::canonical_agg_key; use nodedb_query::msgpack_scan; -/// Build a cache key for an aggregate query. -/// -/// Format: `"{tid}:{collection}\0{group_fields}\0{agg_ops}"`. -/// Null bytes separate sections to avoid ambiguity with field names. +// ── Cache key ────────────────────────────────────────────────────────────── + fn aggregate_cache_key( tid: u32, collection: &str, @@ -60,37 +67,6 @@ fn aggregate_cache_key( key } -/// Group a single document into the binary_groups map. -/// -/// Applies filter predicates, computes group key, and stores the raw -/// document bytes for later aggregation. 
-fn group_doc( - value: &[u8], - group_by: &[String], - filter_predicates: &[ScanFilter], - use_field_index: bool, - binary_groups: &mut std::collections::HashMap>>, -) { - if use_field_index { - let idx = msgpack_scan::FieldIndex::build(value, 0) - .unwrap_or_else(msgpack_scan::FieldIndex::empty); - if !filter_predicates - .iter() - .all(|f| f.matches_binary_indexed(value, &idx)) - { - return; - } - let key = msgpack_scan::group_key::build_group_key_indexed(value, group_by, &idx); - binary_groups.entry(key).or_default().push(value.to_vec()); - } else { - if !filter_predicates.iter().all(|f| f.matches_binary(value)) { - return; - } - let key = msgpack_scan::build_group_key(value, group_by); - binary_groups.entry(key).or_default().push(value.to_vec()); - } -} - fn legacy_aggregate_pairs(aggregates: &[AggregateSpec]) -> Option> { aggregates .iter() @@ -130,6 +106,8 @@ fn apply_user_aliases_to_rows(rows: &mut [serde_json::Value], aggregates: &[Aggr } } +// ── CoreLoop impl ────────────────────────────────────────────────────────── + impl CoreLoop { #[allow(clippy::too_many_arguments)] pub(in crate::data::executor) fn execute_aggregate( @@ -148,8 +126,6 @@ impl CoreLoop { debug!(core = self.core_id, %collection, group_fields = group_by.len(), aggs = aggregates.len(), "aggregate"); // Fast path: incremental aggregate cache. - // If we've cached the result for this exact (collection, group_by, aggregates) - // combination and there are no filters/having, return cached result directly. if filters.is_empty() && having.is_empty() { let cache_key = aggregate_cache_key( tid, @@ -166,9 +142,6 @@ impl CoreLoop { } // Fast path: index-backed COUNT/GROUP BY. - // When GROUP BY has a single field, no filters, no HAVING, and the - // only aggregate is COUNT(*), scan the INDEXES table directly. - // No document table access — O(index_entries) instead of O(documents). 
if group_by.len() == 1 && filters.is_empty() && having.is_empty() @@ -177,12 +150,9 @@ impl CoreLoop { && aggregates[0].function == "count" { let field = &group_by[0]; - // Empty index — fall through to full scan (documents may exist - // without index entries if no secondary indexes are declared). if let Ok(groups) = self.sparse.scan_index_groups(tid, collection, field) && !groups.is_empty() { - // Build result rows as raw msgpack — no serde_json::Value. let mut payload_buf = Vec::with_capacity(groups.len() * 64); let row_count = groups.len().min(limit); let count_key = aggregates[0] @@ -195,8 +165,7 @@ impl CoreLoop { msgpack_scan::write_kv_str(&mut payload_buf, field, &value); msgpack_scan::write_kv_i64(&mut payload_buf, &count_key, count as i64); } - let results_payload = payload_buf; - return match Ok::, crate::Error>(results_payload) { + return match Ok::, crate::Error>(payload_buf) { Ok(payload) => self.response_with_payload(task, payload), Err(e) => self.response_error( task, @@ -208,22 +177,14 @@ impl CoreLoop { } } - // Aggregates must scan all matching documents for correct results. - // Cap at aggregate_scan_cap to prevent OOM on unbounded collections. let scan_limit = self.query_tuning.aggregate_scan_cap; - // If collection has columnar memtable data, read from there. - // Works for all columnar profiles: plain, timeseries, spatial. - // Spatial inserts write to both sparse (R-tree) and columnar (scans/aggregates). let columnar_mt = self .columnar_memtables .get(collection) .filter(|mt| !mt.is_empty()); // Fast path: native columnar aggregation. - // Groups directly on symbol IDs (u32) instead of JSON-serialized strings. - // Accumulates in-place without document construction. - // Falls back to generic path for complex filters (OR, string comparisons). if let Some(mt) = columnar_mt.filter(|_| sub_group_by.is_empty() && sub_aggregates.is_empty()) { @@ -250,7 +211,6 @@ impl CoreLoop { scan_limit, ) }) { - // Apply HAVING filters. 
if !having.is_empty() { let having_predicates: Vec = match zerompk::from_msgpack(having) { Ok(h) => h, @@ -295,13 +255,13 @@ impl CoreLoop { ), }; } - // Native path returned None — fall through to generic path. } - // ── Streaming aggregation: process documents in chunks ── - // Instead of loading all documents into memory, scan in chunks of - // 10K docs, group + aggregate each chunk, then merge partial results. - // Memory: O(chunk_size + num_groups) instead of O(all_docs). + // ── Streaming aggregation ────────────────────────────────────────── + // Documents are processed one at a time. Per-group accumulators hold + // only the derived scalar / approximate state needed for the final + // result — no raw document bytes are retained. + // Memory: O(num_groups × num_aggregates) instead of O(all_docs). let filter_predicates: Vec = if filters.is_empty() { Vec::new() @@ -316,28 +276,51 @@ impl CoreLoop { }; let use_field_index = filter_predicates.len() + group_by.len() >= 2; + let need_sub = !sub_group_by.is_empty() && !sub_aggregates.is_empty(); - // Accumulate per-group doc bytes across all chunks. - // Key: group_key string, Value: collected raw doc bytes for final aggregation. - let mut binary_groups: std::collections::HashMap>> = - std::collections::HashMap::new(); + // outer_group_key → GroupState + let mut groups: HashMap = HashMap::new(); + // outer_group_key → sub_group_key → GroupState + let mut sub_groups: HashMap> = HashMap::new(); let chunk_size = 10_000; - // Universal scan: routes to the correct engine (KV, columnar, sparse/strict) - // and normalizes all results to standard msgpack maps. 
let scan_result = self .scan_collection(tid, collection, scan_limit) .map(|docs| { for chunk in docs.chunks(chunk_size) { for (_, value) in chunk { - group_doc( - value, - group_by, - &filter_predicates, - use_field_index, - &mut binary_groups, - ); + let outer_key = if use_field_index { + let idx = msgpack_scan::FieldIndex::build(value, 0) + .unwrap_or_else(msgpack_scan::FieldIndex::empty); + if !filter_predicates + .iter() + .all(|f| f.matches_binary_indexed(value, &idx)) + { + continue; + } + msgpack_scan::group_key::build_group_key_indexed(value, group_by, &idx) + } else { + if !filter_predicates.iter().all(|f| f.matches_binary(value)) { + continue; + } + msgpack_scan::build_group_key(value, group_by) + }; + + groups + .entry(outer_key.clone()) + .or_insert_with(|| GroupState::new(aggregates)) + .feed(aggregates, value); + + if need_sub { + let sub_key = msgpack_scan::build_group_key(value, sub_group_by); + sub_groups + .entry(outer_key) + .or_default() + .entry(sub_key) + .or_insert_with(|| GroupState::new(sub_aggregates)) + .feed(sub_aggregates, value); + } } } }); @@ -345,12 +328,12 @@ impl CoreLoop { match scan_result { Ok(()) => { let mut results: Vec = Vec::new(); - for (group_key, group_docs) in &binary_groups { + + for (group_key, state) in groups { let mut row = serde_json::Map::new(); - // Insert group-by field values into the result row. 
if !group_by.is_empty() - && let Ok(parts) = sonic_rs::from_str::>(group_key) + && let Ok(parts) = sonic_rs::from_str::>(&group_key) { for (i, field) in group_by.iter().enumerate() { let val = parts.get(i).cloned().unwrap_or(serde_json::Value::Null); @@ -358,32 +341,18 @@ impl CoreLoop { } } - let doc_slices: Vec<&[u8]> = group_docs.iter().map(|d| d.as_slice()).collect(); - - for agg in aggregates { - let val = msgpack_scan::compute_aggregate_binary( - &agg.function, - &agg.field, - agg.expr.as_ref(), - &doc_slices, - ); + for (alias, val) in state.finalize(aggregates) { let json_val: serde_json::Value = val.into(); - row.insert(agg.alias.clone(), json_val); + row.insert(alias, json_val); } - // Nested sub-aggregation on raw bytes. - if !sub_group_by.is_empty() && !sub_aggregates.is_empty() { - let mut sub_groups: std::collections::HashMap> = - std::collections::HashMap::new(); - for doc_bytes in &doc_slices { - let sub_key = msgpack_scan::build_group_key(doc_bytes, sub_group_by); - sub_groups.entry(sub_key).or_default().push(doc_bytes); - } - - let mut sub_results = Vec::new(); - for (sub_key, sub_docs) in &sub_groups { + if need_sub { + let sub_map = sub_groups.remove(&group_key).unwrap_or_default(); + let mut sub_results: Vec = Vec::new(); + for (sub_key, sub_state) in sub_map { let mut sub_row = serde_json::Map::new(); - if let Ok(parts) = sonic_rs::from_str::>(sub_key) + if let Ok(parts) = + sonic_rs::from_str::>(&sub_key) { for (i, field) in sub_group_by.iter().enumerate() { let val = @@ -391,15 +360,9 @@ impl CoreLoop { sub_row.insert(field.clone(), val); } } - for agg in sub_aggregates { - let val = msgpack_scan::compute_aggregate_binary( - &agg.function, - &agg.field, - agg.expr.as_ref(), - sub_docs, - ); + for (alias, val) in sub_state.finalize(sub_aggregates) { let json_val: serde_json::Value = val.into(); - sub_row.insert(agg.alias.clone(), json_val); + sub_row.insert(alias, json_val); } let mut sub_value = serde_json::Value::Object(sub_row); 
apply_user_aliases_to_rows( @@ -421,7 +384,11 @@ impl CoreLoop { let having_predicates: Vec = match zerompk::from_msgpack(having) { Ok(f) => f, Err(e) => { - tracing::warn!(core = self.core_id, error = %e, "HAVING predicate deserialization failed (schemaless)"); + tracing::warn!( + core = self.core_id, + error = %e, + "HAVING predicate deserialization failed (schemaless)" + ); Vec::new() } }; @@ -438,7 +405,6 @@ impl CoreLoop { match super::super::response_codec::encode_json_vec(&results) { Ok(payload) => { - // Cache the result for future identical queries. if filters.is_empty() && having.is_empty() { let cache_key = aggregate_cache_key( tid, @@ -448,7 +414,6 @@ impl CoreLoop { sub_group_by, sub_aggregates, ); - // Bounded cache: max 256 entries per core. if self.aggregate_cache.len() < 256 { self.aggregate_cache.insert(cache_key, payload.clone()); } diff --git a/nodedb/src/data/executor/handlers/mod.rs b/nodedb/src/data/executor/handlers/mod.rs index 9cd56ab7..cfcd0fc8 100644 --- a/nodedb/src/data/executor/handlers/mod.rs +++ b/nodedb/src/data/executor/handlers/mod.rs @@ -1,3 +1,4 @@ +mod accum; pub mod aggregate; pub mod bulk_dml; pub mod columnar_agg; From 861f7a5ca94f04fd05d7951121e040e4795fd09a Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:02:49 +0800 Subject: [PATCH 06/10] feat(columnar): flush memtable to in-memory segments and include in scans When the columnar memtable reaches the flush threshold, drain it into a compressed segment and retain the bytes in memory. Extend scan_normalize to read from flushed segments before falling back to the live memtable, so queries see all rows regardless of whether they have been flushed. This makes columnar scan results consistent across memtable boundaries without requiring a durable write path yet. 
--- nodedb/src/data/executor/core_loop/mod.rs | 8 + .../data/executor/handlers/columnar_write.rs | 34 ++++ nodedb/src/data/executor/scan_normalize.rs | 172 ++++++++++++++++-- 3 files changed, 197 insertions(+), 17 deletions(-) diff --git a/nodedb/src/data/executor/core_loop/mod.rs b/nodedb/src/data/executor/core_loop/mod.rs index 6232dd13..15b02d3c 100644 --- a/nodedb/src/data/executor/core_loop/mod.rs +++ b/nodedb/src/data/executor/core_loop/mod.rs @@ -155,6 +155,13 @@ pub struct CoreLoop { pub(in crate::data::executor) columnar_engines: HashMap, + /// Flushed columnar segment bytes, keyed by "{tid}:{collection}". + /// Each entry is a list of encoded segment buffers produced by `SegmentWriter`. + /// Kept in memory so `scan_columnar` can read rows that were drained from the + /// active memtable during a flush (otherwise those rows would be lost until a + /// real on-disk segment reader is wired up). + pub(in crate::data::executor) columnar_flushed_segments: HashMap>>, + /// Per-collection max WAL LSN that has been ingested into the memtable. /// Used by the WAL catch-up deduplication: if a catch-up record's LSN /// is <= this value, the Data Plane skips it (already ingested). @@ -283,6 +290,7 @@ impl CoreLoop { ), columnar_memtables: HashMap::new(), columnar_engines: HashMap::new(), + columnar_flushed_segments: HashMap::new(), ts_max_ingested_lsn: HashMap::new(), last_ts_ingest: None, ts_last_value_caches: HashMap::new(), diff --git a/nodedb/src/data/executor/handlers/columnar_write.rs b/nodedb/src/data/executor/handlers/columnar_write.rs index c6cc08da..4f31267f 100644 --- a/nodedb/src/data/executor/handlers/columnar_write.rs +++ b/nodedb/src/data/executor/handlers/columnar_write.rs @@ -98,6 +98,40 @@ impl CoreLoop { } } + // Flush memtable to a segment if the threshold has been reached. 
+ if engine.should_flush() { + let new_segment_id = engine.next_segment_id(); + let (schema, columns, row_count) = engine.memtable_mut().drain_optimized(); + if row_count > 0 { + match nodedb_columnar::SegmentWriter::plain() + .write_segment(&schema, &columns, row_count) + { + Ok(bytes) => { + self.columnar_flushed_segments + .entry(collection.to_string()) + .or_default() + .push(bytes); + tracing::debug!( + core = self.core_id, + %collection, + new_segment_id, + row_count, + "columnar memtable flushed and segment bytes retained in memory" + ); + } + Err(e) => { + tracing::warn!( + core = self.core_id, + %collection, + error = %e, + "columnar segment encode failed; flushed rows may be lost" + ); + } + } + } + engine.on_memtable_flushed(new_segment_id); + } + // Populate R-tree for geometry columns so spatial predicates work. { let tid = task.request.tenant_id; diff --git a/nodedb/src/data/executor/scan_normalize.rs b/nodedb/src/data/executor/scan_normalize.rs index 91f18d88..b54da46b 100644 --- a/nodedb/src/data/executor/scan_normalize.rs +++ b/nodedb/src/data/executor/scan_normalize.rs @@ -108,26 +108,83 @@ impl CoreLoop { }; let schema = engine.schema(); - let rows: Vec<_> = engine.scan_memtable_rows().take(limit).collect(); - let mut results = Vec::with_capacity(rows.len()); - - for row in rows { - // Build a nodedb_types::Value::Object directly — no JSON intermediary. - let mut map = std::collections::HashMap::new(); - let mut id = String::new(); - for (i, col_def) in schema.columns.iter().enumerate() { - if i < row.len() { - if col_def.name == "id" - && let nodedb_types::value::Value::String(s) = &row[i] - { - id.clone_from(s); + let mut results = Vec::new(); + + // 1. Read from flushed segments (older rows drained from prior memtable flushes). 
+ if let Some(segments) = self.columnar_flushed_segments.get(collection) { + for seg_bytes in segments { + if results.len() >= limit { + break; + } + let reader = match nodedb_columnar::SegmentReader::open(seg_bytes) { + Ok(r) => r, + Err(e) => { + tracing::warn!(error = %e, "failed to open flushed columnar segment for scan"); + continue; + } + }; + let seg_row_count = reader.row_count() as usize; + let remaining = limit - results.len(); + let take = seg_row_count.min(remaining); + + // Decode all columns for this segment. + let col_count = schema.columns.len(); + let mut decoded_cols = Vec::with_capacity(col_count); + let mut decode_ok = true; + for col_idx in 0..col_count { + match reader.read_column(col_idx) { + Ok(dc) => decoded_cols.push(dc), + Err(e) => { + tracing::warn!(error = %e, col_idx, "failed to decode columnar segment column"); + decode_ok = false; + break; + } + } + } + if !decode_ok { + continue; + } + + for row_idx in 0..take { + let mut map = std::collections::HashMap::new(); + let mut id = String::new(); + for (col_idx, col_def) in schema.columns.iter().enumerate() { + let val = decoded_col_to_value(&decoded_cols[col_idx], row_idx); + if col_def.name == "id" + && let nodedb_types::value::Value::String(s) = &val + { + id.clone_from(s); + } + map.insert(col_def.name.clone(), val); } - map.insert(col_def.name.clone(), row[i].clone()); + let ndb_val = nodedb_types::value::Value::Object(map); + let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); + results.push((id, mp)); } } - let ndb_val = nodedb_types::value::Value::Object(map); - let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); - results.push((id, mp)); + } + + // 2. Read from the live memtable (most-recent rows not yet flushed). 
+ if results.len() < limit { + let remaining = limit - results.len(); + let rows: Vec<_> = engine.scan_memtable_rows().take(remaining).collect(); + for row in rows { + let mut map = std::collections::HashMap::new(); + let mut id = String::new(); + for (i, col_def) in schema.columns.iter().enumerate() { + if i < row.len() { + if col_def.name == "id" + && let nodedb_types::value::Value::String(s) = &row[i] + { + id.clone_from(s); + } + map.insert(col_def.name.clone(), row[i].clone()); + } + } + let ndb_val = nodedb_types::value::Value::Object(map); + let mp = nodedb_types::value_to_msgpack(&ndb_val).unwrap_or_default(); + results.push((id, mp)); + } } results @@ -175,3 +232,84 @@ impl CoreLoop { } } } + +/// Convert a single row from a `DecodedColumn` to a `nodedb_types::value::Value`. +/// +/// Returns `Value::Null` if the row index is out of range or the validity bit is false. +fn decoded_col_to_value( + col: &nodedb_columnar::reader::DecodedColumn, + row_idx: usize, +) -> nodedb_types::value::Value { + use nodedb_columnar::reader::DecodedColumn; + use nodedb_types::value::Value; + + match col { + DecodedColumn::Int64 { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Integer(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Float64 { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Float(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Timestamp { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + // Represent as integer microseconds (same as Value::Integer for timestamps). 
+ Value::Integer(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Bool { values, valid } => { + if row_idx < valid.len() && valid[row_idx] { + Value::Bool(values[row_idx]) + } else { + Value::Null + } + } + DecodedColumn::Binary { + data, + offsets, + valid, + } => { + if row_idx < valid.len() && valid[row_idx] && row_idx + 1 < offsets.len() { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + if start <= end && end <= data.len() { + let bytes = &data[start..end]; + // Best-effort UTF-8 interpretation; fall back to bytes. + match std::str::from_utf8(bytes) { + Ok(s) => Value::String(s.to_string()), + Err(_) => Value::Bytes(bytes.to_vec()), + } + } else { + Value::Null + } + } else { + Value::Null + } + } + DecodedColumn::DictEncoded { + ids, + dictionary, + valid, + } => { + if row_idx < valid.len() && valid[row_idx] { + let id = ids[row_idx] as usize; + if id < dictionary.len() { + Value::String(dictionary[id].clone()) + } else { + Value::Null + } + } else { + Value::Null + } + } + } +} From ebe478701056e08054c8760e582d1a19df36b8b5 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:03:10 +0800 Subject: [PATCH 07/10] fix(pgwire): fix extended-query protocol for DSL statements and typed result columns Two issues in the prepared-statement extended-query path: 1. DSL statements (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) were not handled by the Execute phase. Route them through the same DSL dispatcher used by simple queries; bound parameters are intentionally ignored for DSL. 2. When a statement declares typed result columns via Describe, Execute was producing a single-column JSON response against the N-column schema described to the client, causing null values for columns 2..N. Add a reproject step that parses each JSON object and re-encodes it with one pgwire field per declared column, with missing fields sent as SQL NULL. 
--- .../src/control/server/pgwire/ddl/backup.rs | 2 +- .../server/pgwire/handler/prepared/execute.rs | 149 +++++++++++++++++- .../server/pgwire/handler/prepared/parser.rs | 26 +++ .../pgwire/handler/prepared/statement.rs | 4 + 4 files changed, 178 insertions(+), 3 deletions(-) diff --git a/nodedb/src/control/server/pgwire/ddl/backup.rs b/nodedb/src/control/server/pgwire/ddl/backup.rs index d96d6979..3e496062 100644 --- a/nodedb/src/control/server/pgwire/ddl/backup.rs +++ b/nodedb/src/control/server/pgwire/ddl/backup.rs @@ -244,7 +244,7 @@ pub async fn restore_tenant( })?; let mut aad = [0u8; nodedb_wal::record::HEADER_SIZE]; aad[..6].copy_from_slice(b"BACKUP"); - key.decrypt(0, &aad, &raw_bytes[4..]) + key.decrypt(key.epoch(), 0, &aad, &raw_bytes[4..]) .map_err(|e| sqlstate_error("XX000", &format!("backup decryption failed: {e}")))? } else { raw_bytes diff --git a/nodedb/src/control/server/pgwire/handler/prepared/execute.rs b/nodedb/src/control/server/pgwire/handler/prepared/execute.rs index a33a6378..d4cd2e77 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/execute.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/execute.rs @@ -5,14 +5,17 @@ //! all DDL dispatch, transaction handling, and permission checks. use std::fmt::Debug; +use std::sync::Arc; use bytes::Bytes; +use futures::StreamExt; use futures::sink::Sink; use pgwire::api::portal::Portal; -use pgwire::api::results::Response; +use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response}; use pgwire::api::{ClientInfo, ClientPortalStore, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; +use sonic_rs; use super::super::core::NodeDbPgHandler; use super::statement::ParsedStatement; @@ -39,6 +42,15 @@ impl NodeDbPgHandler { let stmt = &portal.statement.statement; let tenant_id = identity.tenant_id; + // DSL passthroughs (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) cannot be + // handled by the planned-SQL path. 
Route them through the same full DSL + // dispatcher used by the simple-query handler. DSL statements do not use + // SQL parameter placeholders, so bound parameters are intentionally ignored. + if stmt.is_dsl { + let mut results = self.execute_sql(&identity, &addr, &stmt.sql).await?; + return Ok(results.pop().unwrap_or(Response::EmptyQuery)); + } + // Convert pgwire binary parameters to typed ParamValues for AST binding. let params = convert_portal_params(&portal.parameters, &stmt.param_types)?; @@ -46,8 +58,118 @@ impl NodeDbPgHandler { let mut results = self .execute_planned_sql_with_params(&identity, &stmt.sql, tenant_id, &addr, &params) .await?; - Ok(results.pop().unwrap_or(Response::EmptyQuery)) + let result = results.pop().unwrap_or(Response::EmptyQuery); + + // When the statement declared typed result columns via Describe, the + // client expects DataRow messages with one field per declared column. + // + // The generic `payload_to_response` path produces a single-column + // QueryResponse with the full JSON as one text field. In the extended- + // query protocol the RowDescription was already sent by Describe, so + // pgwire sends only the DataRow messages on Execute — the client maps + // them against the previously-described schema. A 1-field row against + // an N-column schema causes null values for columns 2..N. + // + // Fix: when result_fields is non-empty, consume the single-field stream, + // parse each JSON object, and re-encode with one pgwire field per + // declared column. + if !stmt.result_fields.is_empty() { + reproject_response(result, &stmt.result_fields).await + } else { + Ok(result) + } + } +} + +/// Re-encode a query response to match the column schema declared by Describe. +/// +/// Each DataRow from `payload_to_response` contains a single text field holding +/// a JSON object. We parse each object and extract fields in `result_fields` +/// order, producing a new QueryResponse whose rows have one field per declared +/// column.
Missing fields are sent as SQL NULL. +/// +/// Non-query responses (execution tags) pass through unchanged. +async fn reproject_response( + response: Response, + result_fields: &[FieldInfo], +) -> PgWireResult<Response> { + let qr = match response { + Response::Query(qr) => qr, + other => return Ok(other), + }; + + let schema = Arc::new(result_fields.to_vec()); + let field_names: Vec<String> = result_fields.iter().map(|f| f.name().to_string()).collect(); + + // Collect JSON objects from the single-column stream produced by + // payload_to_response. Each DataRow has exactly one field: a JSON string. + let json_rows = collect_json_rows(qr).await?; + + let mut pgwire_rows = Vec::with_capacity(json_rows.len()); + for obj in &json_rows { + let mut encoder = DataRowEncoder::new(schema.clone()); + for name in &field_names { + match obj.get(name) { + None | Some(serde_json::Value::Null) => { + let _ = encoder.encode_field(&Option::<String>::None); + } + Some(v) => { + let text = match v { + serde_json::Value::String(s) => s.clone(), + other => other.to_string(), + }; + let _ = encoder.encode_field(&text); + } + } + } + pgwire_rows.push(Ok(encoder.take_row())); } + + Ok(Response::Query(QueryResponse::new( + schema, + futures::stream::iter(pgwire_rows), + ))) +} + +/// Consume a `QueryResponse` stream and decode the single text field of each +/// `DataRow` as a JSON object. +/// +/// `payload_to_response` always produces rows where field[0] is a JSON string. +/// The pgwire `DataRow.data` format is: for each field, 4-byte length (i32, +/// big-endian) followed by the field bytes. `-1` (0xFFFFFFFF) means SQL NULL. +async fn collect_json_rows(mut qr: QueryResponse) -> PgWireResult<Vec<serde_json::Value>> { + let mut rows = Vec::new(); + while let Some(row_result) = qr.data_rows.next().await { + let row = row_result?; + // Decode field[0] from the raw DataRow wire format.
+ let text = decode_first_field_text(&row.data); + if let Some(t) = text { + let val: serde_json::Value = + sonic_rs::from_str(t).unwrap_or_else(|_| serde_json::Value::String(t.to_string())); + rows.push(val); + } + } + Ok(rows) +} + +/// Decode the text bytes of the first field from a pgwire `DataRow` wire buffer. +/// +/// Wire format: for each field, 4-byte big-endian length followed by bytes. +/// Returns `None` for NULL fields or invalid encodings. +fn decode_first_field_text(data: &bytes::BytesMut) -> Option<&str> { + if data.len() < 4 { + return None; + } + let len = i32::from_be_bytes([data[0], data[1], data[2], data[3]]); + if len < 0 { + // NULL field. + return None; + } + let len = len as usize; + if data.len() < 4 + len { + return None; + } + std::str::from_utf8(&data[4..4 + len]).ok() } /// Convert pgwire portal parameters to typed `ParamValue` for AST-level binding. @@ -156,4 +278,27 @@ mod tests { assert!(matches!(result[0], nodedb_sql::ParamValue::Bool(v) if v == expected)); } } + + #[test] + fn decode_first_field_text_normal() { + // Wire format: 4-byte length (big-endian) + UTF-8 bytes. + let text = b"hello"; + let mut data = bytes::BytesMut::new(); + data.extend_from_slice(&(text.len() as i32).to_be_bytes()); + data.extend_from_slice(text); + assert_eq!(decode_first_field_text(&data), Some("hello")); + } + + #[test] + fn decode_first_field_text_null() { + // -1 length means SQL NULL. 
+ let mut data = bytes::BytesMut::new(); + data.extend_from_slice(&(-1i32).to_be_bytes()); + assert_eq!(decode_first_field_text(&data), None); + } + + #[test] + fn decode_first_field_text_empty() { + assert_eq!(decode_first_field_text(&bytes::BytesMut::new()), None); + } } diff --git a/nodedb/src/control/server/pgwire/handler/prepared/parser.rs b/nodedb/src/control/server/pgwire/handler/prepared/parser.rs index d24f5a52..03d37585 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/parser.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/parser.rs @@ -112,10 +112,17 @@ impl QueryParser for NodeDbQueryParser { .unwrap_or(1); let (param_types, result_fields) = self.try_infer_types(sql, types, tenant_id); + // If type inference produced no result fields and the SQL matches a + // known DSL prefix, mark the statement as a DSL passthrough. The + // Execute handler will route it through the full DSL dispatcher + // (same as the simple-query path) instead of `execute_planned_sql_with_params`. + let is_dsl = result_fields.is_empty() && is_dsl_statement(sql); + Ok(ParsedStatement { sql: sql.to_owned(), param_types, result_fields, + is_dsl, }) } @@ -136,6 +143,25 @@ impl QueryParser for NodeDbQueryParser { } } +/// Return true if `sql` starts with a DSL keyword that `plan_sql` cannot parse. +/// +/// Mirrors the prefix checks in `ddl/router/dsl.rs` so the extended-query +/// Parse handler can mark such statements as DSL passthroughs and route them +/// through the DSL dispatcher at Execute time. 
+fn is_dsl_statement(sql: &str) -> bool { + let upper = sql.trim().to_uppercase(); + upper.starts_with("SEARCH ") + || upper.starts_with("GRAPH ") + || upper.starts_with("MATCH ") + || upper.starts_with("OPTIONAL MATCH ") + || upper.starts_with("CRDT MERGE ") + || upper.starts_with("UPSERT INTO ") + || upper.starts_with("CREATE VECTOR INDEX ") + || upper.starts_with("CREATE FULLTEXT INDEX ") + || upper.starts_with("CREATE SEARCH INDEX ") + || upper.starts_with("CREATE SPARSE INDEX ") +} + /// Count $1, $2, ... placeholders in SQL text. fn count_placeholders(sql: &str) -> usize { let mut max_idx = 0usize; diff --git a/nodedb/src/control/server/pgwire/handler/prepared/statement.rs b/nodedb/src/control/server/pgwire/handler/prepared/statement.rs index cbbde8fa..ed5b0b45 100644 --- a/nodedb/src/control/server/pgwire/handler/prepared/statement.rs +++ b/nodedb/src/control/server/pgwire/handler/prepared/statement.rs @@ -21,4 +21,8 @@ pub struct ParsedStatement { /// Result column schema inferred from the logical plan. /// Empty for DML statements (INSERT/UPDATE/DELETE). pub result_fields: Vec, + /// True when the SQL is a DSL statement (SEARCH, GRAPH, MATCH, UPSERT INTO, + /// etc.) that `plan_sql` cannot parse. The Execute handler routes these + /// through the full DSL dispatcher instead of `execute_planned_sql_with_params`. + pub is_dsl: bool, } From 14fb1d6b48fd57197d405efe19a874e78f04c763 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:03:39 +0800 Subject: [PATCH 08/10] test: add integration tests for columnar aggregates and prepared statements Cover GROUP BY with sum/avg/min/max/count over flushed columnar segments, and extended-query protocol correctness for typed result columns and DSL statement passthrough. 
--- .../executor_tests/test_columnar_aggregate.rs | 125 ++++++++++++++++++ nodedb/tests/sql_prepared_statements.rs | 83 ++++++++++++ 2 files changed, 208 insertions(+) diff --git a/nodedb/tests/executor_tests/test_columnar_aggregate.rs b/nodedb/tests/executor_tests/test_columnar_aggregate.rs index c01a23bc..ebd56060 100644 --- a/nodedb/tests/executor_tests/test_columnar_aggregate.rs +++ b/nodedb/tests/executor_tests/test_columnar_aggregate.rs @@ -124,3 +124,128 @@ fn columnar_having_uses_canonical_key_but_output_keeps_user_alias() { assert_eq!(rows[0]["city_count"].as_u64(), Some(2)); assert!(rows[0].get("count(*)").is_none()); } + +#[test] +fn columnar_insert_triggers_memtable_flush() { + // Spec: after inserting more rows than DEFAULT_FLUSH_THRESHOLD (65536), the + // memtable must be drained to a segment on disk rather than accumulating + // unbounded memory. + let mut ctx = make_ctx(); + + // Build a batch of 70000 rows — above the 65536 flush threshold. + let rows: Vec = (0..70_000) + .map(|i| { + serde_json::json!({ + "id": format!("r{i}"), + "v": i, + }) + }) + .collect(); + let payload = nodedb_types::json_to_msgpack(&serde_json::Value::Array(rows)).unwrap(); + + // The write must succeed without error. Before the fix this would succeed + // but silently accumulate all rows in RAM; after the fix the engine flushes + // the memtable to a segment once the threshold is crossed. + send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Columnar(ColumnarOp::Insert { + collection: "large_col".into(), + payload, + format: "msgpack".into(), + }), + ); + + // All rows must be readable back — the segment flush must not lose data. 
+ let doc_count = ctx + .core + .scan_collection(1, "large_col", 70_001) + .unwrap() + .len(); + assert_eq!( + doc_count, 70_000, + "all inserted rows must be scannable after flush" + ); +} + +#[test] +fn aggregate_group_by_does_not_require_full_materialization() { + // Spec: GROUP BY aggregation must return correct per-group results regardless + // of whether the implementation uses running aggregates (O(groups)) or + // full doc materialization (O(rows)). This test locks in correctness; + // the fix changes internal memory usage from O(N) to O(groups). + let mut ctx = make_ctx(); + + // Insert 1000 rows across 10 groups (g0..g9), each group gets 100 rows. + let rows: Vec = (0..1_000) + .map(|i| { + serde_json::json!({ + "id": format!("r{i}"), + "g": format!("g{}", i % 10), + "v": i, + }) + }) + .collect(); + let payload = nodedb_types::json_to_msgpack(&serde_json::Value::Array(rows)).unwrap(); + + send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Columnar(ColumnarOp::Insert { + collection: "grouped".into(), + payload, + format: "msgpack".into(), + }), + ); + + let payload = send_ok( + &mut ctx.core, + &mut ctx.tx, + &mut ctx.rx, + PhysicalPlan::Query(QueryOp::Aggregate { + collection: "grouped".into(), + group_by: vec!["g".into()], + aggregates: vec![ + AggregateSpec { + function: "count".into(), + alias: "count(*)".into(), + user_alias: None, + field: "*".into(), + expr: None, + }, + AggregateSpec { + function: "sum".into(), + alias: "sum(v)".into(), + user_alias: None, + field: "v".into(), + expr: None, + }, + ], + filters: Vec::new(), + having: Vec::new(), + limit: 100, + sub_group_by: Vec::new(), + sub_aggregates: Vec::new(), + }), + ); + + let result = payload_value(&payload); + let result_rows = result + .as_array() + .unwrap_or_else(|| panic!("expected aggregate rows, got {result}")); + + assert_eq!( + result_rows.len(), + 10, + "GROUP BY must produce exactly 10 groups" + ); + for row in result_rows { + assert_eq!( + 
row["count(*)"].as_u64(), + Some(100), + "each group must contain exactly 100 rows, got: {row}" + ); + } +} diff --git a/nodedb/tests/sql_prepared_statements.rs b/nodedb/tests/sql_prepared_statements.rs index ace3a7a2..44eadf75 100644 --- a/nodedb/tests/sql_prepared_statements.rs +++ b/nodedb/tests/sql_prepared_statements.rs @@ -23,3 +23,86 @@ async fn prepare_execute_deallocate_lifecycle() { server.exec("DEALLOCATE ALL").await.unwrap(); server.expect_error("EXECUTE q1", "does not exist").await; } + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_search_vector_dsl() { + let server = TestServer::start().await; + + // Create a document collection and a vector index on the embedding field. + server + .exec("CREATE COLLECTION vec_ep TYPE document") + .await + .unwrap(); + server + .exec("CREATE VECTOR INDEX idx_vec_ep ON vec_ep METRIC cosine DIM 3") + .await + .unwrap(); + + // Insert a document with an embedding vector. + server + .exec("INSERT INTO vec_ep (id, embedding) VALUES ('v1', ARRAY[1.0, 0.0, 0.0])") + .await + .unwrap(); + + // DSL SEARCH statements must not be rejected by the extended-protocol path + // with "Expected: an SQL statement". The statement should succeed and return + // results (or an empty result set — the key is no parse-time rejection). + let result = server + .query_text("SEARCH vec_ep USING VECTOR(embedding, ARRAY[1.0, 0.0, 0.0], 3)") + .await; + assert!( + result.is_ok(), + "SEARCH via extended protocol must not be rejected: {:?}", + result.err() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_upsert_dsl() { + let server = TestServer::start().await; + + server.exec("CREATE COLLECTION upsert_ep").await.unwrap(); + + // UPSERT INTO DSL statements must not be rejected by the extended-protocol + // path with "Expected: an SQL statement". 
+ let result = server + .exec("UPSERT INTO upsert_ep { id: 'u1', name: 'alice' }") + .await; + assert!( + result.is_ok(), + "UPSERT INTO via extended protocol must not be rejected: {:?}", + result.err() + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn prepared_select_strict_doc_returns_data() { + let server = TestServer::start().await; + + server + .exec( + "CREATE COLLECTION strict_ep TYPE DOCUMENT STRICT \ + (id TEXT PRIMARY KEY, name TEXT)", + ) + .await + .unwrap(); + server + .exec("INSERT INTO strict_ep (id, name) VALUES ('a', 'alice')") + .await + .unwrap(); + + // SELECT on a STRICT doc collection via the extended-query protocol must + // return the inserted row with actual column values, not null/empty columns. + let rows = server + .query_text("SELECT id, name FROM strict_ep WHERE id = 'a'") + .await + .unwrap(); + assert!(!rows.is_empty(), "SELECT should return the inserted row"); + + // Regression guard: the row must contain actual data, not null. + assert!( + rows[0].contains("alice"), + "extended protocol must not return null columns for STRICT doc, got: {:?}", + rows[0] + ); +} From d8987d303e2440ba256e8e95b80f6e6ca00def65 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:21:13 +0800 Subject: [PATCH 09/10] refactor(aggregate): flatten nested conditionals in accumulator hot path Replace nested if-let and if-inside-if-let patterns with let-chain conditions in AggAccum::accumulate. Eliminates unnecessary nesting, makes null-skip and dedup conditions read left-to-right, and avoids a spurious clone of COUNT(*) row in the non-null branch. Also replace manual .max(64).min(MAX_EF_SEARCH) with .clamp() in the ef_search sizing helper for clarity. 
--- nodedb/src/data/executor/handlers/accum.rs | 74 +++++++++---------- .../data/executor/handlers/vector_search.rs | 2 +- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/nodedb/src/data/executor/handlers/accum.rs b/nodedb/src/data/executor/handlers/accum.rs index 3809bf50..e269d72f 100644 --- a/nodedb/src/data/executor/handlers/accum.rs +++ b/nodedb/src/data/executor/handlers/accum.rs @@ -118,9 +118,9 @@ impl AggAccum { use nodedb_query::msgpack_scan::aggregate_helpers as ah; match self { AggAccum::Count { n } => { - if agg.field == "*" && agg.expr.is_none() { - *n += 1; - } else if ah::extract_non_null(doc, &agg.field, agg.expr.as_ref()).is_some() { + if (agg.field == "*" && agg.expr.is_none()) + || ah::extract_non_null(doc, &agg.field, agg.expr.as_ref()).is_some() + { *n += 1; } } @@ -168,10 +168,10 @@ impl AggAccum { } } AggAccum::CountDistinct { seen } => { - if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { - if bytes != [0xc0u8] { - seen.insert(bytes); - } + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + seen.insert(bytes); } } AggAccum::Welford { n, mean, m2 } => { @@ -184,10 +184,10 @@ impl AggAccum { } } AggAccum::Hll { hll } => { - if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { - if bytes != [0xc0u8] { - hll.add(fnv1a(&bytes)); - } + if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + hll.add(fnv1a(&bytes)); } } AggAccum::TDigest { digest } => { @@ -198,45 +198,43 @@ impl AggAccum { } AggAccum::TopK { ss, .. 
} => { let actual = field_after_colon(&agg.field); - if let Some(bytes) = ah::extract_bytes(doc, actual, agg.expr.as_ref()) { - if bytes != [0xc0u8] { - ss.add(fnv1a(&bytes)); - } + if let Some(bytes) = ah::extract_bytes(doc, actual, agg.expr.as_ref()) + && bytes != [0xc0u8] + { + ss.add(fnv1a(&bytes)); } } AggAccum::ArrayAgg { values } => { - if values.len() < ARRAY_AGG_CAP { - if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { - if !v.is_null() { - values.push(v); - } - } + if values.len() < ARRAY_AGG_CAP + && let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) + && !v.is_null() + { + values.push(v); } } AggAccum::ArrayAggDistinct { seen, values } => { - if values.len() < ARRAY_AGG_CAP { - if let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) { - if bytes != [0xc0u8] && seen.insert(bytes) { - if let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) { - values.push(v); - } - } - } + if values.len() < ARRAY_AGG_CAP + && let Some(bytes) = ah::extract_bytes(doc, &agg.field, agg.expr.as_ref()) + && bytes != [0xc0u8] + && seen.insert(bytes) + && let Some(v) = ah::extract_value(doc, &agg.field, agg.expr.as_ref()) + { + values.push(v); } } AggAccum::PercentileCont { values, .. 
} => { - if values.len() < ARRAY_AGG_CAP { - let actual = field_after_colon(&agg.field); - if let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) { - values.push(v); - } + let actual = field_after_colon(&agg.field); + if values.len() < ARRAY_AGG_CAP + && let Some(v) = ah::extract_f64(doc, actual, agg.expr.as_ref()) + { + values.push(v); } } AggAccum::StringAgg { parts } => { - if parts.len() < ARRAY_AGG_CAP { - if let Some(s) = ah::extract_str(doc, &agg.field, agg.expr.as_ref()) { - parts.push(s); - } + if parts.len() < ARRAY_AGG_CAP + && let Some(s) = ah::extract_str(doc, &agg.field, agg.expr.as_ref()) + { + parts.push(s); } } } diff --git a/nodedb/src/data/executor/handlers/vector_search.rs b/nodedb/src/data/executor/handlers/vector_search.rs index 88941456..c2773df3 100644 --- a/nodedb/src/data/executor/handlers/vector_search.rs +++ b/nodedb/src/data/executor/handlers/vector_search.rs @@ -359,6 +359,6 @@ fn effective_ef(ef_search: usize, top_k: usize) -> usize { if ef_search > 0 { ef_search.max(top_k).min(MAX_EF_SEARCH) } else { - top_k.saturating_mul(4).max(64).min(MAX_EF_SEARCH) + top_k.saturating_mul(4).clamp(64, MAX_EF_SEARCH) } } From 09bf7de46327bbe48bf8c3988b5f1987b374bec6 Mon Sep 17 00:00:00 2001 From: Farhan Syah Date: Thu, 16 Apr 2026 19:21:20 +0800 Subject: [PATCH 10/10] fix: replace unsafe arithmetic chains with explicit early-return bounds checks nodedb-vector/src/mmap_segment.rs: rewrite the mmap offset bounds check from a single nested checked_add/checked_mul expression into sequential early-returns via let-else. Each overflow or out-of-bounds condition is now its own guard, making the failure paths obvious. nodedb-wal/src/crypto.rs: replace .clone() calls on Copy epoch values in tests with copy-dereference to avoid unnecessary clone on a type that implements Copy. 
--- nodedb-vector/src/mmap_segment.rs | 17 +++++++++++------ nodedb-wal/src/crypto.rs | 12 ++++++------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/nodedb-vector/src/mmap_segment.rs b/nodedb-vector/src/mmap_segment.rs index cf37f4e7..0c8947fb 100644 --- a/nodedb-vector/src/mmap_segment.rs +++ b/nodedb-vector/src/mmap_segment.rs @@ -171,13 +171,18 @@ impl MmapVectorSegment { Some(v) => v, None => return, }; - let offset = match self - .data_offset - .checked_add(idx.checked_mul(byte_len).unwrap_or(usize::MAX)) - { - Some(v) if v.checked_add(byte_len).is_some_and(|e| e <= self.mmap_size) => v, - _ => return, + let Some(idx_bytes) = idx.checked_mul(byte_len) else { + return; + }; + let Some(offset) = self.data_offset.checked_add(idx_bytes) else { + return; }; + if offset + .checked_add(byte_len) + .is_none_or(|e| e > self.mmap_size) + { + return; + } let page_start = offset & !(4095); let len = (byte_len + 4095) & !(4095); unsafe { diff --git a/nodedb-wal/src/crypto.rs b/nodedb-wal/src/crypto.rs index d19a23fe..9bf525cd 100644 --- a/nodedb-wal/src/crypto.rs +++ b/nodedb-wal/src/crypto.rs @@ -227,7 +227,7 @@ mod tests { #[test] fn encrypt_decrypt_roundtrip() { let key = test_key(); - let epoch = key.epoch().clone(); + let epoch = *key.epoch(); let header = test_header(1); let plaintext = b"hello nodedb encryption"; @@ -242,7 +242,7 @@ mod tests { #[test] fn wrong_key_fails() { let key1 = WalEncryptionKey::from_bytes(&[0x01; 32]); - let epoch1 = key1.epoch().clone(); + let epoch1 = *key1.epoch(); let key2 = WalEncryptionKey::from_bytes(&[0x02; 32]); let header = test_header(1); @@ -253,7 +253,7 @@ mod tests { #[test] fn wrong_lsn_fails() { let key = test_key(); - let epoch = key.epoch().clone(); + let epoch = *key.epoch(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"secret").unwrap(); @@ -264,7 +264,7 @@ mod tests { #[test] fn tampered_ciphertext_fails() { let key = test_key(); - let epoch = key.epoch().clone(); + let 
epoch = *key.epoch(); let header = test_header(1); let mut ciphertext = key.encrypt(1, &header, b"secret").unwrap(); @@ -275,7 +275,7 @@ mod tests { #[test] fn tampered_header_fails() { let key = test_key(); - let epoch = key.epoch().clone(); + let epoch = *key.epoch(); let header1 = test_header(1); let ciphertext = key.encrypt(1, &header1, b"secret").unwrap(); @@ -289,7 +289,7 @@ mod tests { #[test] fn empty_payload() { let key = test_key(); - let epoch = key.epoch().clone(); + let epoch = *key.epoch(); let header = test_header(1); let ciphertext = key.encrypt(1, &header, b"").unwrap();