From b8228ccd5a99068035235926890a6e817fd5460f Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 8 Oct 2025 20:43:02 -0500 Subject: [PATCH 01/15] parse filters --- server/src/routes/search.rs | 127 +++++++++++++++++++++++++++++++++--- server/tests/tests.rs | 92 ++++++++++++++++++++++++-- 2 files changed, 207 insertions(+), 12 deletions(-) diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 5bdd3e10..812eeab6 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -13,6 +13,113 @@ use vectorize_core::transformers::providers::prepare_generic_embedding_request; use vectorize_core::transformers::types::Inputs; use vectorize_core::types::VectorizeJob; +/// Filter operators supported by the search API +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum FilterOperator { + /// Equal to (=) + #[serde(rename = "eq")] + Equal, + /// Greater than (>) + #[serde(rename = "gt")] + GreaterThan, + /// Greater than or equal (>=) + #[serde(rename = "gte")] + GreaterThanOrEqual, + /// Less than (<) + #[serde(rename = "lt")] + LessThan, + /// Less than or equal (<=) + #[serde(rename = "lte")] + LessThanOrEqual, +} + +impl FilterOperator { + /// Convert operator to SQL operator string + pub fn to_sql(&self) -> &'static str { + match self { + FilterOperator::Equal => "=", + FilterOperator::GreaterThan => ">", + FilterOperator::GreaterThanOrEqual => ">=", + FilterOperator::LessThan => "<", + FilterOperator::LessThanOrEqual => "<=", + } + } +} + +/// A filter value with an operator +#[derive(Debug, Clone, Serialize)] +pub struct FilterValue { + pub operator: FilterOperator, + pub value: String, +} + +impl FilterValue { + /// Convert to the original FilterValue format for compatibility with vectorize_core + pub fn to_legacy_filter_value(&self) -> query::FilterValue { + // For now, we'll convert to String variant - this may need to be updated + // based on the actual FilterValue enum structure in vectorize_core + 
query::FilterValue::String(self.value.clone()) + } +} + +/// Custom deserializer for FilterValue that parses operator.value format +impl<'de> Deserialize<'de> for FilterValue { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{self, Visitor}; + use std::fmt; + + struct FilterValueVisitor; + + impl<'de> Visitor<'de> for FilterValueVisitor { + type Value = FilterValue; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string in format 'operator.value' or just 'value'") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + if let Some(dot_pos) = value.find('.') { + let operator_str = &value[..dot_pos]; + let val = &value[dot_pos + 1..]; + + let operator = match operator_str { + "eq" => FilterOperator::Equal, + "gt" => FilterOperator::GreaterThan, + "gte" => FilterOperator::GreaterThanOrEqual, + "lt" => FilterOperator::LessThan, + "lte" => FilterOperator::LessThanOrEqual, + _ => { + return Err(de::Error::custom(format!( + "Unknown operator: {}", + operator_str + ))); + } + }; + + Ok(FilterValue { + operator, + value: val.to_string(), + }) + } else { + // Default to equality if no operator specified + Ok(FilterValue { + operator: FilterOperator::Equal, + value: value.to_string(), + }) + } + } + } + + deserializer.deserialize_str(FilterValueVisitor) + } +} + #[derive(Serialize, Deserialize, Debug, Clone, ToSchema, FromRow)] pub struct SearchRequest { pub job_name: String, @@ -28,7 +135,7 @@ pub struct SearchRequest { #[serde(default = "default_fts_wt")] pub fts_wt: f32, #[serde(flatten, default)] - pub filters: BTreeMap, + pub filters: BTreeMap, } fn default_semantic_wt() -> f32 { @@ -89,10 +196,7 @@ pub async fn search( for (key, value) in &payload.filters { // validate key and value query::check_input(key)?; - if let query::FilterValue::String(value) = value { - // only need to check the value if it is a raw string - 
query::check_input(value)?; - } + query::check_input(&value.value)?; } } @@ -131,6 +235,13 @@ pub async fn search( let embedding_request = prepare_generic_embedding_request(&vectorizejob.model, &[input]); let embeddings = provider.generate_embedding(&embedding_request).await?; + // Convert our new filters to legacy format for compatibility with vectorize_core + let legacy_filters: BTreeMap = payload + .filters + .iter() + .map(|(key, value)| (key.clone(), value.to_legacy_filter_value())) + .collect(); + let q = query::hybrid_search_query( &payload.job_name, &vectorizejob.src_schema, @@ -142,15 +253,15 @@ pub async fn search( payload.rrf_k, payload.semantic_wt, payload.fts_wt, - &payload.filters, + &legacy_filters, ); let mut prepared_query = sqlx::query(&q) .bind(&embeddings.embeddings[0]) .bind(&payload.query); - // Bind filter values using the same BTreeMap instance used by the query builder - for value in payload.filters.values() { + // Bind filter values using the legacy format + for value in legacy_filters.values() { prepared_query = value.bind_to_query(prepared_query); } diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 3c92a211..811ca553 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -149,8 +149,8 @@ async fn test_search_filters() { resp.status() ); - // filter a query by product_category - let params = format!("job_name={job_name}&query=pen&product_category=electronics",); + // filter a query by product_category (using eq operator) + let params = format!("job_name={job_name}&query=pen&product_category=eq.electronics",); let search_results = common::search_with_retry(¶ms, 9).await.unwrap(); assert_eq!(search_results.len(), 9); @@ -159,8 +159,8 @@ async fn test_search_filters() { assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); } - // filter by price - let params = format!("job_name={job_name}&query=electronics&price=25"); + // filter by price using less than operator + let params = 
format!("job_name={job_name}&query=electronics&price=lt.30"); let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); assert_eq!(search_results.len(), 2); assert_eq!( @@ -171,6 +171,90 @@ async fn test_search_filters() { search_results[1]["product_name"].as_str().unwrap(), "Alarm Clock" ); + + // test greater than or equal operator + let params = format!("job_name={job_name}&query=electronics&price=gte.25"); + let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); + assert_eq!(search_results.len(), 2); + + // test backward compatibility - no operator should default to equality + let params = format!("job_name={job_name}&query=pen&product_category=electronics",); + let search_results = common::search_with_retry(¶ms, 9).await.unwrap(); + assert_eq!(search_results.len(), 9); + for result in search_results { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + } +} + +#[tokio::test] +async fn test_search_filter_operators() { + let mut rng = rand::rng(); + let test_num = rng.random_range(1..100000); + let cfg = vectorize_core::config::Config::from_env(); + let sql = std::fs::read_to_string("sql/example.sql").unwrap(); + common::exec_psql(&cfg.database_url, &sql); + + let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); + // test table + let table = format!("test_filter_ops_{test_num}"); + let drop_sql = format!("DROP TABLE IF EXISTS public.{table};"); + let create_sql = + format!("CREATE TABLE public.{table} (LIKE public.my_products INCLUDING ALL);"); + let insert_sql = format!("INSERT INTO public.{table} SELECT * FROM public.my_products;"); + + sqlx::query(&drop_sql).execute(&pool).await.unwrap(); + sqlx::query(&create_sql).execute(&pool).await.unwrap(); + sqlx::query(&insert_sql).execute(&pool).await.unwrap(); + + // initialize search job + let job_name = format!("test_filter_ops_{test_num}"); + let payload = json!({ + "job_name": job_name, + "src_table": table, + "src_schema": "public", + 
"src_columns": ["description"], + "primary_key": "product_id", + "update_time_col": "updated_at", + "model": "sentence-transformers/all-MiniLM-L6-v2" + }); + + let client = reqwest::Client::new(); + let resp = client + .post("http://localhost:8080/api/v1/table") + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await + .expect("Failed to send request"); + + assert_eq!( + resp.status(), + reqwest::StatusCode::OK, + "Response status: {:?}", + resp.status() + ); + + // Test different operators + // Greater than + let params = format!("job_name={job_name}&query=electronics&price=gt.20"); + let search_results = common::search_with_retry(¶ms, 3).await.unwrap(); + assert_eq!(search_results.len(), 3); + + // Less than or equal + let params = format!("job_name={job_name}&query=electronics&price=lte.25"); + let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); + assert_eq!(search_results.len(), 2); + + // Test invalid operator (should return error) + let params = format!("job_name={job_name}&query=electronics&price=invalid.25"); + let response = client + .get(&format!("http://localhost:8080/api/v1/search?{}", params)) + .send() + .await + .expect("Failed to send request"); + + // Should return an error for invalid operator + assert!(response.status().is_client_error() || response.status().is_server_error()); } /// proxy is an incomplete feature From 54ede5f7a2c99deab514d0c8901588078e9d01fe Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 8 Oct 2025 20:50:52 -0500 Subject: [PATCH 02/15] refact --- core/src/query.rs | 123 +++++++++++++++++++++++------------ server/src/routes/search.rs | 126 ++---------------------------------- 2 files changed, 90 insertions(+), 159 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 3cc32afb..54cdaf58 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -7,59 +7,101 @@ use sqlx::postgres::PgRow; use sqlx::{Postgres, Row}; use std::collections::BTreeMap; use 
tiktoken_rs::cl100k_base; -pub const VECTORIZE_SCHEMA: &str = "vectorize"; -static TRIGGER_FN_PREFIX: &str = "vectorize.handle_update_"; -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum FilterValue { - String(String), - Integer(i64), - Float(f64), - Boolean(bool), +/// Filter operators supported by the search API +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum FilterOperator { + /// Equal to (=) + Equal, + /// Greater than (>) + GreaterThan, + /// Greater than or equal (>=) + GreaterThanOrEqual, + /// Less than (<) + LessThan, + /// Less than or equal (<=) + LessThanOrEqual, +} + +impl FilterOperator { + /// Convert operator to SQL operator string + pub fn to_sql(&self) -> &'static str { + match self { + FilterOperator::Equal => "=", + FilterOperator::GreaterThan => ">", + FilterOperator::GreaterThanOrEqual => ">=", + FilterOperator::LessThan => "<", + FilterOperator::LessThanOrEqual => "<=", + } + } } +/// A filter value with an operator +#[derive(Debug, Clone, Serialize)] +pub struct FilterValue { + pub operator: FilterOperator, + pub value: String, +} + +/// Custom deserializer for FilterValue that parses operator.value format impl<'de> serde::Deserialize<'de> for FilterValue { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - let s = String::deserialize(deserializer)?; - - // Try to parse as boolean first - if let Ok(b) = s.parse::() { - return Ok(FilterValue::Boolean(b)); - } + use serde::de::{self, Visitor}; + use std::fmt; - // Try to parse as integer - if let Ok(i) = s.parse::() { - return Ok(FilterValue::Integer(i)); - } + struct FilterValueVisitor; - // Try to parse as float - if let Ok(f) = s.parse::() { - return Ok(FilterValue::Float(f)); - } + impl<'de> Visitor<'de> for FilterValueVisitor { + type Value = FilterValue; - // Fall back to string - - Ok(FilterValue::String(s)) - } -} + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string in format 
'operator.value' or just 'value'") + } -impl FilterValue { - pub fn bind_to_query<'q>( - &'q self, - query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>, - ) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> { - match self { - FilterValue::String(s) => query.bind(s), - FilterValue::Integer(i) => query.bind(*i), - FilterValue::Float(f) => query.bind(*f), - FilterValue::Boolean(b) => query.bind(*b), + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + if let Some(dot_pos) = value.find('.') { + let operator_str = &value[..dot_pos]; + let val = &value[dot_pos + 1..]; + + let operator = match operator_str { + "eq" => FilterOperator::Equal, + "gt" => FilterOperator::GreaterThan, + "gte" => FilterOperator::GreaterThanOrEqual, + "lt" => FilterOperator::LessThan, + "lte" => FilterOperator::LessThanOrEqual, + _ => { + return Err(de::Error::custom(format!( + "Unknown operator: {}", + operator_str + ))); + } + }; + + Ok(FilterValue { + operator, + value: val.to_string(), + }) + } else { + // Default to equality if no operator specified + Ok(FilterValue { + operator: FilterOperator::Equal, + value: value.to_string(), + }) + } + } } + + deserializer.deserialize_str(FilterValueVisitor) } } +pub const VECTORIZE_SCHEMA: &str = "vectorize"; +static TRIGGER_FN_PREFIX: &str = "vectorize.handle_update_"; fn generate_column_concat(src_columns: &[String], prefix: &str) -> String { src_columns @@ -573,8 +615,9 @@ pub fn hybrid_search_query( let mut bind_value_counter: i16 = 3; let mut where_filter = "WHERE 1=1".to_string(); - for column in filters.keys() { - let filt = format!(" AND t0.\"{column}\" = ${bind_value_counter}"); + for (column, filter_value) in filters.iter() { + let operator = filter_value.operator.to_sql(); + let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}"); where_filter.push_str(&filt); bind_value_counter += 1; } diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs 
index 812eeab6..0f8b5ba9 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -8,118 +8,11 @@ use std::sync::Arc; use tokio::sync::RwLock; use utoipa::ToSchema; use uuid::Uuid; -use vectorize_core::query; +use vectorize_core::query::{self, FilterValue}; use vectorize_core::transformers::providers::prepare_generic_embedding_request; use vectorize_core::transformers::types::Inputs; use vectorize_core::types::VectorizeJob; -/// Filter operators supported by the search API -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum FilterOperator { - /// Equal to (=) - #[serde(rename = "eq")] - Equal, - /// Greater than (>) - #[serde(rename = "gt")] - GreaterThan, - /// Greater than or equal (>=) - #[serde(rename = "gte")] - GreaterThanOrEqual, - /// Less than (<) - #[serde(rename = "lt")] - LessThan, - /// Less than or equal (<=) - #[serde(rename = "lte")] - LessThanOrEqual, -} - -impl FilterOperator { - /// Convert operator to SQL operator string - pub fn to_sql(&self) -> &'static str { - match self { - FilterOperator::Equal => "=", - FilterOperator::GreaterThan => ">", - FilterOperator::GreaterThanOrEqual => ">=", - FilterOperator::LessThan => "<", - FilterOperator::LessThanOrEqual => "<=", - } - } -} - -/// A filter value with an operator -#[derive(Debug, Clone, Serialize)] -pub struct FilterValue { - pub operator: FilterOperator, - pub value: String, -} - -impl FilterValue { - /// Convert to the original FilterValue format for compatibility with vectorize_core - pub fn to_legacy_filter_value(&self) -> query::FilterValue { - // For now, we'll convert to String variant - this may need to be updated - // based on the actual FilterValue enum structure in vectorize_core - query::FilterValue::String(self.value.clone()) - } -} - -/// Custom deserializer for FilterValue that parses operator.value format -impl<'de> Deserialize<'de> for FilterValue { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - 
use serde::de::{self, Visitor}; - use std::fmt; - - struct FilterValueVisitor; - - impl<'de> Visitor<'de> for FilterValueVisitor { - type Value = FilterValue; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string in format 'operator.value' or just 'value'") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - if let Some(dot_pos) = value.find('.') { - let operator_str = &value[..dot_pos]; - let val = &value[dot_pos + 1..]; - - let operator = match operator_str { - "eq" => FilterOperator::Equal, - "gt" => FilterOperator::GreaterThan, - "gte" => FilterOperator::GreaterThanOrEqual, - "lt" => FilterOperator::LessThan, - "lte" => FilterOperator::LessThanOrEqual, - _ => { - return Err(de::Error::custom(format!( - "Unknown operator: {}", - operator_str - ))); - } - }; - - Ok(FilterValue { - operator, - value: val.to_string(), - }) - } else { - // Default to equality if no operator specified - Ok(FilterValue { - operator: FilterOperator::Equal, - value: value.to_string(), - }) - } - } - } - - deserializer.deserialize_str(FilterValueVisitor) - } -} - #[derive(Serialize, Deserialize, Debug, Clone, ToSchema, FromRow)] pub struct SearchRequest { pub job_name: String, @@ -235,13 +128,6 @@ pub async fn search( let embedding_request = prepare_generic_embedding_request(&vectorizejob.model, &[input]); let embeddings = provider.generate_embedding(&embedding_request).await?; - // Convert our new filters to legacy format for compatibility with vectorize_core - let legacy_filters: BTreeMap = payload - .filters - .iter() - .map(|(key, value)| (key.clone(), value.to_legacy_filter_value())) - .collect(); - let q = query::hybrid_search_query( &payload.job_name, &vectorizejob.src_schema, @@ -253,16 +139,18 @@ pub async fn search( payload.rrf_k, payload.semantic_wt, payload.fts_wt, - &legacy_filters, + &payload.filters, ); + log::warn!("Search query: {}", q); + let mut prepared_query = sqlx::query(&q) 
.bind(&embeddings.embeddings[0]) .bind(&payload.query); - // Bind filter values using the legacy format - for value in legacy_filters.values() { - prepared_query = value.bind_to_query(prepared_query); + // Bind filter values + for value in payload.filters.values() { + prepared_query = prepared_query.bind(&value.value); } let results = prepared_query.fetch_all(&**pool).await?; From 9fcea12a33c7b6dae5975cb712c53c7df19c8d9a Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 8 Oct 2025 20:59:12 -0500 Subject: [PATCH 03/15] simple operators --- core/src/query.rs | 64 +++++++++++++++++++++++++++++++++++-- server/src/routes/search.rs | 10 +++--- server/tests/tests.rs | 24 +++++++++++--- server/tests/util.rs | 5 ++- 4 files changed, 91 insertions(+), 12 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 54cdaf58..8b280287 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -40,7 +40,26 @@ impl FilterOperator { #[derive(Debug, Clone, Serialize)] pub struct FilterValue { pub operator: FilterOperator, - pub value: String, + pub value: FilterValueType, +} + +/// The actual value stored in a filter - can be string or numeric +#[derive(Debug, Clone, Serialize)] +pub enum FilterValueType { + String(String), + Integer(i64), + Float(f64), +} + +impl FilterValueType { + /// Get the value as a string for SQL binding + pub fn as_sql_value(&self) -> String { + match self { + FilterValueType::String(s) => s.clone(), + FilterValueType::Integer(i) => i.to_string(), + FilterValueType::Float(f) => f.to_string(), + } + } } /// Custom deserializer for FilterValue that parses operator.value format @@ -83,15 +102,54 @@ impl<'de> serde::Deserialize<'de> for FilterValue { } }; + // Parse the value based on the operator + let parsed_value = match operator { + FilterOperator::Equal => { + // For equality, try to parse as number first, fallback to string + if let Ok(int_val) = val.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = 
val.parse::() { + FilterValueType::Float(float_val) + } else { + FilterValueType::String(val.to_string()) + } + } + FilterOperator::GreaterThan | + FilterOperator::GreaterThanOrEqual | + FilterOperator::LessThan | + FilterOperator::LessThanOrEqual => { + // For comparison operators, require numeric values + if let Ok(int_val) = val.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = val.parse::() { + FilterValueType::Float(float_val) + } else { + return Err(de::Error::custom(format!( + "Comparison operators (gt, gte, lt, lte) require numeric values, got: '{}'", + val + ))); + } + } + }; + Ok(FilterValue { operator, - value: val.to_string(), + value: parsed_value, }) } else { // Default to equality if no operator specified + // Try to parse as number first, fallback to string + let parsed_value = if let Ok(int_val) = value.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = value.parse::() { + FilterValueType::Float(float_val) + } else { + FilterValueType::String(value.to_string()) + }; + Ok(FilterValue { operator: FilterOperator::Equal, - value: value.to_string(), + value: parsed_value, }) } } diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 0f8b5ba9..2fbb9695 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -89,7 +89,7 @@ pub async fn search( for (key, value) in &payload.filters { // validate key and value query::check_input(key)?; - query::check_input(&value.value)?; + query::check_input(&value.value.as_sql_value())?; } } @@ -142,15 +142,17 @@ pub async fn search( &payload.filters, ); - log::warn!("Search query: {}", q); - let mut prepared_query = sqlx::query(&q) .bind(&embeddings.embeddings[0]) .bind(&payload.query); // Bind filter values for value in payload.filters.values() { - prepared_query = prepared_query.bind(&value.value); + prepared_query = match &value.value { + query::FilterValueType::String(s) => prepared_query.bind(s), + 
query::FilterValueType::Integer(i) => prepared_query.bind(i), + query::FilterValueType::Float(f) => prepared_query.bind(f), + }; } let results = prepared_query.fetch_all(&**pool).await?; diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 811ca553..008f23e4 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -160,7 +160,7 @@ async fn test_search_filters() { } // filter by price using less than operator - let params = format!("job_name={job_name}&query=electronics&price=lt.30"); + let params = format!("job_name={job_name}&query=electronics&price=eq.25"); let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); assert_eq!(search_results.len(), 2); assert_eq!( @@ -173,9 +173,9 @@ async fn test_search_filters() { ); // test greater than or equal operator - let params = format!("job_name={job_name}&query=electronics&price=gte.25"); - let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); - assert_eq!(search_results.len(), 2); + let params = format!("job_name={job_name}&query=electronics&price=gte.25&limit=5"); + let search_results = common::search_with_retry(¶ms, 5).await.unwrap(); + assert_eq!(search_results.len(), 5); // test backward compatibility - no operator should default to equality let params = format!("job_name={job_name}&query=pen&product_category=electronics",); @@ -245,6 +245,11 @@ async fn test_search_filter_operators() { let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); assert_eq!(search_results.len(), 2); + // Test float values + let params = format!("job_name={job_name}&query=electronics&price=gte.24.5"); + let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); + assert_eq!(search_results.len(), 2); + // Test invalid operator (should return error) let params = format!("job_name={job_name}&query=electronics&price=invalid.25"); let response = client @@ -255,6 +260,17 @@ async fn test_search_filter_operators() { // Should return an error for invalid operator 
assert!(response.status().is_client_error() || response.status().is_server_error()); + + // Test non-numeric value with comparison operator (should return error) + let params = format!("job_name={job_name}&query=electronics&price=gt.abc"); + let response = client + .get(&format!("http://localhost:8080/api/v1/search?{}", params)) + .send() + .await + .expect("Failed to send request"); + + // Should return an error for non-numeric value with comparison operator + assert!(response.status().is_client_error() || response.status().is_server_error()); } /// proxy is an incomplete feature diff --git a/server/tests/util.rs b/server/tests/util.rs index d2f6cf69..005e660d 100644 --- a/server/tests/util.rs +++ b/server/tests/util.rs @@ -76,7 +76,10 @@ pub mod common { // Check if we've exceeded the timeout if start_time.elapsed() >= timeout_duration { - panic!("Search request timed out after 10 seconds"); + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "Search request timed out after 10 seconds", + )))? 
} // Wait before retrying From c5c3ee13f80dd2ebb9388cf749bc532a76fa3d44 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 8 Oct 2025 21:04:04 -0500 Subject: [PATCH 04/15] test --- core/src/query.rs | 8 ++++---- server/tests/tests.rs | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 8b280287..1afb9a62 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -114,10 +114,10 @@ impl<'de> serde::Deserialize<'de> for FilterValue { FilterValueType::String(val.to_string()) } } - FilterOperator::GreaterThan | - FilterOperator::GreaterThanOrEqual | - FilterOperator::LessThan | - FilterOperator::LessThanOrEqual => { + FilterOperator::GreaterThan + | FilterOperator::GreaterThanOrEqual + | FilterOperator::LessThan + | FilterOperator::LessThanOrEqual => { // For comparison operators, require numeric values if let Ok(int_val) = val.parse::() { FilterValueType::Integer(int_val) diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 008f23e4..96ced803 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -159,7 +159,6 @@ async fn test_search_filters() { assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); } - // filter by price using less than operator let params = format!("job_name={job_name}&query=electronics&price=eq.25"); let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); assert_eq!(search_results.len(), 2); From 311822dcf055e05b040cd44b91175273968fa60a Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Wed, 8 Oct 2025 21:10:02 -0500 Subject: [PATCH 05/15] test filter order --- core/src/query.rs | 4 +-- server/tests/tests.rs | 60 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 1afb9a62..0a0bfe33 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -7,6 +7,8 @@ use sqlx::postgres::PgRow; use sqlx::{Postgres, Row}; use std::collections::BTreeMap; 
use tiktoken_rs::cl100k_base; +pub const VECTORIZE_SCHEMA: &str = "vectorize"; +static TRIGGER_FN_PREFIX: &str = "vectorize.handle_update_"; /// Filter operators supported by the search API #[derive(Debug, Clone, PartialEq, Serialize)] @@ -158,8 +160,6 @@ impl<'de> serde::Deserialize<'de> for FilterValue { deserializer.deserialize_str(FilterValueVisitor) } } -pub const VECTORIZE_SCHEMA: &str = "vectorize"; -static TRIGGER_FN_PREFIX: &str = "vectorize.handle_update_"; fn generate_column_concat(src_columns: &[String], prefix: &str) -> String { src_columns diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 96ced803..f442ff51 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -183,6 +183,66 @@ async fn test_search_filters() { for result in search_results { assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); } + + // test multiple filters - category first, then price + let params = format!( + "job_name={job_name}&query=electronics&product_category=eq.electronics&price=gte.25" + ); + let search_results_category_first = common::search_with_retry(&params, 5).await.unwrap(); + assert_eq!(search_results_category_first.len(), 5); + for result in &search_results_category_first { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + assert!(result["price"].as_f64().unwrap() >= 25.0); + } + + // test multiple filters - price first, then category (different order) + let params = format!( + "job_name={job_name}&query=electronics&price=gte.25&product_category=eq.electronics" + ); + let search_results_price_first = common::search_with_retry(&params, 5).await.unwrap(); + assert_eq!(search_results_price_first.len(), 5); + for result in &search_results_price_first { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + assert!(result["price"].as_f64().unwrap() >= 25.0); + } + + // verify that both filter orders produce the same results + assert_eq!( + search_results_category_first.len(), + 
search_results_price_first.len() + ); + // Sort both results by product_id to ensure consistent comparison + let mut category_first_sorted = search_results_category_first.clone(); + let mut price_first_sorted = search_results_price_first.clone(); + category_first_sorted.sort_by(|a, b| { + a["product_id"] + .as_i64() + .unwrap() + .cmp(&b["product_id"].as_i64().unwrap()) + }); + price_first_sorted.sort_by(|a, b| { + a["product_id"] + .as_i64() + .unwrap() + .cmp(&b["product_id"].as_i64().unwrap()) + }); + + for (i, (result1, result2)) in category_first_sorted + .iter() + .zip(price_first_sorted.iter()) + .enumerate() + { + assert_eq!( + result1["product_id"], result2["product_id"], + "Product IDs should match at index {}", + i + ); + assert_eq!( + result1["product_name"], result2["product_name"], + "Product names should match at index {}", + i + ); + } } #[tokio::test] From abad7ed956d483e06fbe9d58c4b89e21c1a4b95c Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 06:41:38 -0500 Subject: [PATCH 06/15] fix test --- server/tests/tests.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/server/tests/tests.rs b/server/tests/tests.rs index f442ff51..9d0822a5 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -295,19 +295,19 @@ async fn test_search_filter_operators() { // Test different operators // Greater than - let params = format!("job_name={job_name}&query=electronics&price=gt.20"); - let search_results = common::search_with_retry(&params, 3).await.unwrap(); - assert_eq!(search_results.len(), 3); + let params = format!("job_name={job_name}&query=electronics&price=gt.20&limit=100"); + let search_results = common::search_with_retry(&params, 14).await.unwrap(); + assert_eq!(search_results.len(), 14); // Less than or equal - let params = format!("job_name={job_name}&query=electronics&price=lte.25"); - let search_results = common::search_with_retry(&params, 2).await.unwrap(); - assert_eq!(search_results.len(), 2); + let 
params = format!("job_name={job_name}&query=electronics&price=lte.25limit=100"); + let search_results = common::search_with_retry(&params, 30).await.unwrap(); + assert_eq!(search_results.len(), 30); // Test float values - let params = format!("job_name={job_name}&query=electronics&price=gte.24.5"); - let search_results = common::search_with_retry(&params, 2).await.unwrap(); - assert_eq!(search_results.len(), 2); + let params = format!("job_name={job_name}&query=electronics&price=gte.24.5limit=1000"); + let search_results = common::search_with_retry(&params, 12).await.unwrap(); + assert_eq!(search_results.len(), 12); // Test invalid operator (should return error) let params = format!("job_name={job_name}&query=electronics&price=invalid.25"); From ec6436d45c0ed7834ff0981b5a6296668f0f703b Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 06:50:27 -0500 Subject: [PATCH 07/15] handle decimals --- server/src/routes/search.rs | 6 +++--- server/tests/tests.rs | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 2fbb9695..076fd764 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -86,10 +86,10 @@ pub async fn search( // check inputs and filters are valid if they exist and create a SQL string for them query::check_input(&payload.job_name)?; if !payload.filters.is_empty() { - for (key, value) in &payload.filters { - // validate key and value + for key in payload.filters.keys() { + // validate key only (column names should be alphanumeric + underscore) query::check_input(key)?; - query::check_input(&value.value.as_sql_value())?; + // Note: filter values are validated during deserialization in FilterValue } } diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 9d0822a5..d522b0ad 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -300,12 +300,12 @@ async fn test_search_filter_operators() { assert_eq!(search_results.len(), 14); // Less than 
or equal - let params = format!("job_name={job_name}&query=electronics&price=lte.25limit=100"); + let params = format!("job_name={job_name}&query=electronics&price=lte.25&limit=100"); let search_results = common::search_with_retry(&params, 30).await.unwrap(); assert_eq!(search_results.len(), 30); // Test float values - let params = format!("job_name={job_name}&query=electronics&price=gte.24.5limit=1000"); + let params = format!("job_name={job_name}&query=electronics&price=gte.24.5&limit=1000"); let search_results = common::search_with_retry(&params, 12).await.unwrap(); assert_eq!(search_results.len(), 12); From 72d968919161eee5f0bc4641e4bbbd50e4a4576d Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 07:25:27 -0500 Subject: [PATCH 08/15] add tests --- core/src/query.rs | 783 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 777 insertions(+), 6 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 0a0bfe33..9c1d6755 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -45,12 +45,13 @@ pub struct FilterValue { pub value: FilterValueType, } -/// The actual value stored in a filter - can be string or numeric -#[derive(Debug, Clone, Serialize)] +/// The actual value stored in a filter +#[derive(Debug, Clone, PartialEq, Serialize)] pub enum FilterValueType { String(String), Integer(i64), Float(f64), + Boolean(bool), } impl FilterValueType { @@ -60,6 +61,7 @@ impl FilterValueType { FilterValueType::String(s) => s.clone(), FilterValueType::Integer(i) => i.to_string(), FilterValueType::Float(f) => f.to_string(), + FilterValueType::Boolean(b) => b.to_string(), } } } @@ -107,8 +109,10 @@ impl<'de> serde::Deserialize<'de> for FilterValue { // Parse the value based on the operator let parsed_value = match operator { FilterOperator::Equal => { - // For equality, try to parse as number first, fallback to string - if let Ok(int_val) = val.parse::<i64>() { + // For equality, try to parse as boolean first, then number, fallback to string + if let 
Ok(bool_val) = val.parse::() { + FilterValueType::Boolean(bool_val) + } else if let Ok(int_val) = val.parse::() { FilterValueType::Integer(int_val) } else if let Ok(float_val) = val.parse::() { FilterValueType::Float(float_val) @@ -140,8 +144,10 @@ impl<'de> serde::Deserialize<'de> for FilterValue { }) } else { // Default to equality if no operator specified - // Try to parse as number first, fallback to string - let parsed_value = if let Ok(int_val) = value.parse::() { + // Try to parse as boolean first, then number, fallback to string + let parsed_value = if let Ok(bool_val) = value.parse::() { + FilterValueType::Boolean(bool_val) + } else if let Ok(int_val) = value.parse::() { FilterValueType::Integer(int_val) } else if let Ok(float_val) = value.parse::() { FilterValueType::Float(float_val) @@ -744,6 +750,7 @@ pub fn hybrid_search_query( #[cfg(test)] mod tests { use super::*; + use serde_json; #[test] fn test_create_update_trigger_single() { @@ -776,4 +783,768 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" let result = create_event_trigger(job_name, "myschema", table_name, "INSERT"); assert_eq!(expected, result); } + + // ===== FilterValue Deserialization Tests ===== + + #[test] + fn test_filter_value_deserialize_equality_string() { + let json = "\"eq.hello\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello"); + } + + #[test] + fn test_filter_value_deserialize_equality_integer() { + let json = "\"eq.42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_equality_float() { + let json = "\"eq.3.14\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), 
"3.14"); + } + + #[test] + fn test_filter_value_deserialize_greater_than() { + let json = "\"gt.100\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::GreaterThan); + assert_eq!(filter.value.as_sql_value(), "100"); + } + + #[test] + fn test_filter_value_deserialize_greater_than_or_equal() { + let json = "\"gte.50.5\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::GreaterThanOrEqual); + assert_eq!(filter.value.as_sql_value(), "50.5"); + } + + #[test] + fn test_filter_value_deserialize_less_than() { + let json = "\"lt.25\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::LessThan); + assert_eq!(filter.value.as_sql_value(), "25"); + } + + #[test] + fn test_filter_value_deserialize_less_than_or_equal() { + let json = "\"lte.10.0\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::LessThanOrEqual); + assert_eq!(filter.value.as_sql_value(), "10"); + } + + #[test] + fn test_filter_value_deserialize_default_equality() { + let json = "\"hello\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello"); + } + + #[test] + fn test_filter_value_deserialize_default_equality_numeric() { + let json = "\"42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + // ===== Edge Case Tests ===== + + #[test] + fn test_filter_value_deserialize_empty_string() { + let json = "\"eq.\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), ""); + } + + #[test] + fn 
test_filter_value_deserialize_zero_values() { + let json_int = "\"eq.0\""; + let filter_int: FilterValue = serde_json::from_str(json_int).unwrap(); + assert_eq!(filter_int.operator, FilterOperator::Equal); + assert_eq!(filter_int.value.as_sql_value(), "0"); + + let json_float = "\"eq.0.0\""; + let filter_float: FilterValue = serde_json::from_str(json_float).unwrap(); + assert_eq!(filter_float.operator, FilterOperator::Equal); + assert_eq!(filter_float.value.as_sql_value(), "0"); + } + + #[test] + fn test_filter_value_deserialize_negative_values() { + let json_int = "\"eq.-42\""; + let filter_int: FilterValue = serde_json::from_str(json_int).unwrap(); + assert_eq!(filter_int.operator, FilterOperator::Equal); + assert_eq!(filter_int.value.as_sql_value(), "-42"); + + let json_float = "\"eq.-3.14\""; + let filter_float: FilterValue = serde_json::from_str(json_float).unwrap(); + assert_eq!(filter_float.operator, FilterOperator::Equal); + assert_eq!(filter_float.value.as_sql_value(), "-3.14"); + } + + #[test] + fn test_filter_value_deserialize_special_characters() { + let json = "\"eq.hello-world_123\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello-world_123"); + } + + #[test] + fn test_filter_value_deserialize_unicode_characters() { + let json = "\"eq.测试\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "测试"); + } + + #[test] + fn test_filter_value_deserialize_whitespace_values() { + let json = "\"eq. 
hello \""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), " hello "); + } + + #[test] + fn test_filter_value_deserialize_scientific_notation() { + let json = "\"eq.1e5\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "100000"); + } + + #[test] + fn test_filter_value_deserialize_large_numbers() { + let json = "\"eq.9223372036854775807\""; // i64::MAX + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "9223372036854775807"); + } + + #[test] + fn test_filter_value_deserialize_precision_float() { + let json = "\"eq.3.141592653589793\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "3.141592653589793"); + } + + // ===== SQL Injection Attack Tests ===== + + #[test] + fn test_filter_value_deserialize_sql_injection_basic() { + // Test basic SQL injection attempts + let malicious_inputs = vec![ + "'; DROP TABLE users; --", + "' OR '1'='1", + "' UNION SELECT * FROM users --", + "'; INSERT INTO users VALUES ('hacker', 'password'); --", + "' OR 1=1 --", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_with_operators() { + // Test SQL injection with different operators + let malicious_inputs = vec![ + ("gt", "1; DROP TABLE users; --"), + ("gte", "0 OR 1=1 --"), + ( + "lt", + "999; INSERT INTO users VALUES ('hacker', 'password'); --", 
+ ), + ("lte", "100 UNION SELECT * FROM users --"), + ]; + + for (op, malicious_input) in malicious_inputs { + let json = format!("\"{}.{}\"", op, malicious_input); + let result: Result = serde_json::from_str(&json); + // These should fail because comparison operators require numeric values + assert!( + result.is_err(), + "Should fail for non-numeric value with comparison operator: {}", + malicious_input + ); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_script_tags() { + // Test XSS-style attacks that might be used in SQL injection + let malicious_inputs = vec![ + "", + "javascript:alert('xss')", + "'; ; --", + "' OR ''='", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_encoding_attempts() { + // Test various encoding attempts to bypass filters + let malicious_inputs = vec![ + "%27%20OR%201%3D1%20--", // URL encoded + "' OR 1=1 --", // HTML entity encoded + "' OR CHAR(49)=CHAR(49) --", // CHAR function + "' OR ASCII('a')=97 --", // ASCII function + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_time_based() { + // Test time-based SQL injection attempts + let malicious_inputs = vec![ + "'; WAITFOR DELAY '00:00:05' --", + "' OR SLEEP(5) --", + "'; SELECT pg_sleep(5); --", + "' OR (SELECT COUNT(*) FROM users WHERE username='admin' AND SLEEP(5))>0 --", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", 
malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_blind() { + // Test blind SQL injection attempts + let malicious_inputs = vec![ + "' AND (SELECT SUBSTRING(password,1,1) FROM users WHERE username='admin')='a' --", + "' OR (SELECT COUNT(*) FROM users WHERE username='admin' AND password LIKE 'a%')>0 --", + "' AND EXISTS(SELECT * FROM users WHERE username='admin' AND password LIKE 'a%') --", + "' OR (SELECT ASCII(SUBSTRING(password,1,1)) FROM users WHERE username='admin')=97 --", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_union() { + // Test UNION-based SQL injection attempts + let malicious_inputs = vec![ + "' UNION SELECT username, password FROM users --", + "' UNION ALL SELECT NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL --", + "' UNION SELECT 1,2,3,4,5,6,7,8,9,10 --", + "' UNION SELECT table_name FROM information_schema.tables --", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + #[test] + fn test_filter_value_deserialize_sql_injection_error_based() { + // Test error-based SQL injection attempts + let malicious_inputs = vec![ + "' AND (SELECT * FROM (SELECT COUNT(*),CONCAT(version(),FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a) --", + "' AND EXTRACTVALUE(1, CONCAT(0x7e, 
(SELECT version()), 0x7e)) --", + "' AND (SELECT * FROM (SELECT COUNT(*),CONCAT(CAST((SELECT version()) AS CHAR),0x7e,FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a) --", + ]; + + for malicious_input in malicious_inputs { + let json = format!("\"eq.{}\"", malicious_input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), malicious_input); + } + } + + // ===== Error Handling Tests ===== + + #[test] + fn test_filter_value_deserialize_invalid_operator() { + let json = "\"invalid.42\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err(), "Should fail for invalid operator"); + let error = result.unwrap_err(); + assert!(error.to_string().contains("Unknown operator")); + } + + #[test] + fn test_filter_value_deserialize_comparison_with_string() { + // Test that comparison operators fail with non-numeric values + let test_cases = vec![ + ("gt", "hello"), + ("gte", "world"), + ("lt", "test"), + ("lte", "string"), + ]; + + for (op, value) in test_cases { + let json = format!("\"{}.{}\"", op, value); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for non-numeric value with {} operator", + op + ); + let error = result.unwrap_err(); + assert!(error.to_string().contains("require numeric values")); + } + } + + #[test] + fn test_filter_value_deserialize_malformed_json() { + // Test various malformed JSON inputs + let malformed_inputs = vec![ + ("\"eq.42", false), // Missing closing quote + ("eq.42\"", false), // Missing opening quote + ("\"eq.42\"", true), // This should work + ("\"eq.42.extra\"", true), // Extra dot should work as string + ("\"eq.\"", true), // Empty value should work + ("\".42\"", false), // Missing operator should fail + ("\"eq\"", true), // Missing dot and value should work as string + ]; + + for (input, should_succeed) in malformed_inputs { + let result: 
Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + } else { + assert!( + result.is_err(), + "Should fail for malformed input: {}", + input + ); + } + } + } + + #[test] + fn test_filter_value_deserialize_empty_input() { + let json = "\"\""; + let result: Result = serde_json::from_str(json); + // Empty input should succeed and default to equality with empty string + assert!(result.is_ok(), "Empty input should succeed"); + let filter = result.unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), ""); + } + + #[test] + fn test_filter_value_deserialize_just_dot() { + let json = "\".\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err(), "Should fail for input with just a dot"); + } + + #[test] + fn test_filter_value_deserialize_multiple_dots() { + let json = "\"eq.42.extra\""; + let result: Result = serde_json::from_str(json); + // Multiple dots should succeed and treat the whole thing as a string + assert!( + result.is_ok(), + "Should succeed for input with multiple dots" + ); + let filter = result.unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42.extra"); + } + + #[test] + fn test_filter_value_deserialize_case_sensitive_operators() { + // Test that operators are case sensitive + let case_variations = vec!["EQ.42", "GT.42", "GTE.42", "LT.42", "LTE.42"]; + + for input in case_variations { + let json = format!("\"{}\"", input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for uppercase operator: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_whitespace_in_operator() { + // Test operators with whitespace - these actually succeed as they're treated as strings + let whitespace_inputs = vec![ + ("\" eq.42\"", false), // Leading space - fails + ("\"eq .42\"", false), // Space before dot 
- fails + ("\"eq. 42\"", true), // Space after dot - succeeds as string + ("\"eq.42 \"", true), // Trailing space - succeeds as string + ]; + + for (input, should_succeed) in whitespace_inputs { + let result: Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + } else { + assert!(result.is_err(), "Should fail for input: {}", input); + } + } + } + + // ===== Numeric Parsing Edge Cases ===== + + #[test] + fn test_filter_value_deserialize_numeric_boundaries() { + // Test integer boundaries + let json_max_i64 = "\"eq.9223372036854775807\""; // i64::MAX + let filter_max: FilterValue = serde_json::from_str(json_max_i64).unwrap(); + assert_eq!(filter_max.value.as_sql_value(), "9223372036854775807"); + + let json_min_i64 = "\"eq.-9223372036854775808\""; // i64::MIN + let filter_min: FilterValue = serde_json::from_str(json_min_i64).unwrap(); + assert_eq!(filter_min.value.as_sql_value(), "-9223372036854775808"); + } + + #[test] + fn test_filter_value_deserialize_float_precision() { + // Test various float precision cases + let test_cases = vec![ + ("0.0", "0"), + ("0.1", "0.1"), + ("0.01", "0.01"), + ("0.001", "0.001"), + ("1.0", "1"), + ("1.1", "1.1"), + ("1.11", "1.11"), + ("1.111", "1.111"), + ]; + + for (input, expected) in test_cases { + let json = format!("\"eq.{}\"", input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!( + filter.value.as_sql_value(), + expected, + "Failed for input: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_scientific_notation_edge_cases() { + // Test scientific notation edge cases + let test_cases = vec![ + ("1e0", "1"), + ("1e1", "10"), + ("1e-1", "0.1"), + ("1e-10", "0.0000000001"), + ("1e10", "10000000000"), + ("1.5e2", "150"), + ("1.5e-2", "0.015"), + ]; + + for (input, expected) in test_cases { + let json = format!("\"eq.{}\"", input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + 
assert_eq!( + filter.value.as_sql_value(), + expected, + "Failed for input: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_hex_numbers() { + // Test that hex numbers are treated as strings (not parsed as integers) + let json = "\"eq.0xFF\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "0xFF"); + } + + #[test] + fn test_filter_value_deserialize_octal_numbers() { + // Test that octal numbers are parsed as integers + let json = "\"eq.0777\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "777"); + } + + #[test] + fn test_filter_value_deserialize_binary_numbers() { + // Test that binary numbers are treated as strings + let json = "\"eq.0b1010\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "0b1010"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_leading_zeros() { + // Test numbers with leading zeros + let json = "\"eq.007\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "7"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_trailing_zeros() { + // Test numbers with trailing zeros + let json = "\"eq.42.000\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_plus_sign() { + // Test numbers with explicit plus sign + let json = "\"+42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_plus_sign_float() { + // Test floats with explicit plus sign (should work as default equality) + let json = "\"+3.14\""; + let result: Result = serde_json::from_str(json); + // This should fail because "+3" is not a valid operator + 
assert!( + result.is_err(), + "Should fail for input with plus sign as operator" + ); + } + + #[test] + fn test_filter_value_deserialize_numeric_infinity() { + // Test infinity values (should be parsed as float infinity) + let json = "\"eq.infinity\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "inf"); + } + + #[test] + fn test_filter_value_deserialize_numeric_nan() { + // Test NaN values (should be parsed as float NaN) + let json = "\"eq.nan\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // NaN comparison requires special handling + match filter.value { + FilterValueType::Float(f) => assert!(f.is_nan(), "Expected NaN"), + _ => panic!("Expected Float(NaN)"), + } + assert_eq!(filter.value.as_sql_value(), "NaN"); + } + + #[test] + fn test_filter_value_deserialize_numeric_very_small() { + // Test very small numbers + let json = "\"eq.0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // Very small numbers get parsed as floats and converted to "0" when using to_string() + assert_eq!(filter.value.as_sql_value(), "0"); + } + + #[test] + fn test_filter_value_deserialize_numeric_very_large() { + // Test very large numbers (using f64::MAX as a reasonable upper bound) + let json = "\"eq.1.7976931348623157e308\""; // f64::MAX + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // Very large numbers get parsed as floats and converted to a long decimal string when using to_string() + assert_eq!( + filter.value.as_sql_value(), + 
"179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + ); + } + + // ===== Boolean Filter Value Tests ===== + + #[test] + fn test_filter_value_deserialize_boolean_true() { + let json = "\"eq.true\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(true)); + assert_eq!(filter.value.as_sql_value(), "true"); + } + + #[test] + fn test_filter_value_deserialize_boolean_false() { + let json = "\"eq.false\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(false)); + assert_eq!(filter.value.as_sql_value(), "false"); + } + + #[test] + fn test_filter_value_deserialize_boolean_default_true() { + let json = "\"true\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(true)); + assert_eq!(filter.value.as_sql_value(), "true"); + } + + #[test] + fn test_filter_value_deserialize_boolean_default_false() { + let json = "\"false\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(false)); + assert_eq!(filter.value.as_sql_value(), "false"); + } + + #[test] + fn test_filter_value_deserialize_boolean_case_sensitive() { + // Test that boolean parsing is case sensitive - uppercase values become strings + let test_cases = vec![ + ("\"eq.True\"", FilterValueType::String("True".to_string())), + ("\"eq.False\"", 
FilterValueType::String("False".to_string())), + ("\"eq.TRUE\"", FilterValueType::String("TRUE".to_string())), + ("\"eq.FALSE\"", FilterValueType::String("FALSE".to_string())), + ("\"eq.true\"", FilterValueType::Boolean(true)), + ("\"eq.false\"", FilterValueType::Boolean(false)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_with_whitespace() { + // Test boolean parsing with whitespace + let test_cases = vec![ + ("\"eq. true\"", true), // Space before true - should succeed as string + ("\"eq.false \"", true), // Space after false - should succeed as string + ("\"eq. true \"", true), // Spaces around true - should succeed as string + ]; + + for (input, should_succeed) in test_cases { + let result: Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + let filter = result.unwrap(); + // With whitespace, it should be parsed as a string, not boolean + assert!(matches!(filter.value, FilterValueType::String(_))); + } else { + assert!(result.is_err(), "Should fail for input: {}", input); + } + } + } + + #[test] + fn test_filter_value_deserialize_boolean_vs_string() { + // Test that "true" and "false" strings are not parsed as booleans + let test_cases = vec![ + ( + "\"eq.true_string\"", + FilterValueType::String("true_string".to_string()), + ), + ( + "\"eq.false_string\"", + FilterValueType::String("false_string".to_string()), + ), + ( + "\"eq.true123\"", + FilterValueType::String("true123".to_string()), + ), + ( + "\"eq.false456\"", + FilterValueType::String("false456".to_string()), + ), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_vs_numeric() 
{ + // Test that numeric values are still parsed as numbers, not booleans + let test_cases = vec![ + ("\"eq.1\"", FilterValueType::Integer(1)), + ("\"eq.0\"", FilterValueType::Integer(0)), + ("\"eq.1.0\"", FilterValueType::Float(1.0)), + ("\"eq.0.0\"", FilterValueType::Float(0.0)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_comparison_operators() { + // Test that comparison operators with boolean values fail (as they should require numeric values) + let test_cases = vec!["gt.true", "gte.false", "lt.true", "lte.false"]; + + for input in test_cases { + let json = format!("\"{}\"", input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for boolean value with comparison operator: {}", + input + ); + let error = result.unwrap_err(); + assert!(error.to_string().contains("require numeric values")); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_edge_cases() { + // Test edge cases for boolean parsing + let test_cases = vec![ + ("\"eq.true\"", FilterValueType::Boolean(true)), + ("\"eq.false\"", FilterValueType::Boolean(false)), + ("\"true\"", FilterValueType::Boolean(true)), + ("\"false\"", FilterValueType::Boolean(false)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + assert_eq!(filter.operator, FilterOperator::Equal); + } + } } From 19febb2b27830b0a2516336a4b5f2bdac84a1388 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 07:41:42 -0500 Subject: [PATCH 09/15] add more test cases --- core/src/query.rs | 224 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 196 insertions(+), 28 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index 9c1d6755..fa06b5e9 100644 --- 
a/core/src/query.rs +++ b/core/src/query.rs @@ -117,6 +117,9 @@ impl<'de> serde::Deserialize<'de> for FilterValue { } else if let Ok(float_val) = val.parse::() { FilterValueType::Float(float_val) } else { + // Validate string values for SQL injection + validate_filter_value(val) + .map_err(|e| de::Error::custom(e.to_string()))?; FilterValueType::String(val.to_string()) } } @@ -152,6 +155,9 @@ impl<'de> serde::Deserialize<'de> for FilterValue { } else if let Ok(float_val) = value.parse::() { FilterValueType::Float(float_val) } else { + // Validate string values for SQL injection + validate_filter_value(value) + .map_err(|e| de::Error::custom(e.to_string()))?; FilterValueType::String(value.to_string()) }; @@ -188,6 +194,105 @@ pub fn check_input(input: &str) -> Result<()> { } } +/// Validates filter values against SQL injection patterns +/// Returns an error if malicious patterns are detected +pub fn validate_filter_value(value: &str) -> Result<()> { + let value_lower = value.to_lowercase(); + + // Check for common SQL injection patterns + let malicious_patterns = [ + // Basic SQL injection patterns + "';", + "' or", + "' union", + "' and", + "' drop", + "' delete", + "' insert", + "' update", + "' create", + "' alter", + "' exec", + "' execute", + "' script", + "javascript:", + "") || value.contains("<")) { + return Err(anyhow!( + "Potentially malicious input detected: '{}' contains suspicious character combination", + value + )); + } + + // Check for excessive special characters that might indicate encoding attempts + let special_char_count = value + .chars() + .filter(|c| !c.is_alphanumeric() && *c != ' ' && *c != '.' 
&& *c != '-' && *c != '_') + .count(); + if special_char_count > value.len() / 4 { + return Err(anyhow!( + "Potentially malicious input detected: '{}' contains excessive special characters", + value + )); + } + + Ok(()) +} + pub fn create_vectorize_table() -> String { "CREATE TABLE IF NOT EXISTS vectorize.job ( @@ -946,7 +1051,7 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" #[test] fn test_filter_value_deserialize_sql_injection_basic() { - // Test basic SQL injection attempts + // Test basic SQL injection attempts - these should now be rejected let malicious_inputs = vec![ "'; DROP TABLE users; --", "' OR '1'='1", @@ -957,9 +1062,18 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } @@ -990,7 +1104,7 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" #[test] fn test_filter_value_deserialize_sql_injection_script_tags() { - // Test XSS-style attacks that might be used in SQL injection + // Test XSS-style attacks that might be used in SQL injection - these should now be rejected let malicious_inputs = vec![ "", "javascript:alert('xss')", @@ -1000,15 +1114,24 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: 
Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } #[test] fn test_filter_value_deserialize_sql_injection_encoding_attempts() { - // Test various encoding attempts to bypass filters + // Test various encoding attempts to bypass filters - these should now be rejected let malicious_inputs = vec![ "%27%20OR%201%3D1%20--", // URL encoded "' OR 1=1 --", // HTML entity encoded @@ -1018,15 +1141,24 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } #[test] fn test_filter_value_deserialize_sql_injection_time_based() { - // Test time-based SQL injection attempts + // Test time-based SQL injection attempts - these should now be rejected let malicious_inputs = vec![ "'; WAITFOR DELAY '00:00:05' --", "' OR SLEEP(5) --", @@ -1036,15 +1168,24 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = 
result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } #[test] fn test_filter_value_deserialize_sql_injection_blind() { - // Test blind SQL injection attempts + // Test blind SQL injection attempts - these should now be rejected let malicious_inputs = vec![ "' AND (SELECT SUBSTRING(password,1,1) FROM users WHERE username='admin')='a' --", "' OR (SELECT COUNT(*) FROM users WHERE username='admin' AND password LIKE 'a%')>0 --", @@ -1054,15 +1195,24 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } #[test] fn test_filter_value_deserialize_sql_injection_union() { - // Test UNION-based SQL injection attempts + // Test UNION-based SQL injection attempts - these should now be rejected let malicious_inputs = vec![ "' UNION SELECT username, password FROM users --", "' UNION ALL SELECT NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL --", @@ -1072,15 +1222,24 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); 
+ assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } #[test] fn test_filter_value_deserialize_sql_injection_error_based() { - // Test error-based SQL injection attempts + // Test error-based SQL injection attempts - these should now be rejected let malicious_inputs = vec![ "' AND (SELECT * FROM (SELECT COUNT(*),CONCAT(version(),FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a) --", "' AND EXTRACTVALUE(1, CONCAT(0x7e, (SELECT version()), 0x7e)) --", @@ -1089,9 +1248,18 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" for malicious_input in malicious_inputs { let json = format!("\"eq.{}\"", malicious_input); - let filter: FilterValue = serde_json::from_str(&json).unwrap(); - assert_eq!(filter.operator, FilterOperator::Equal); - assert_eq!(filter.value.as_sql_value(), malicious_input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should reject malicious input: {}", + malicious_input + ); + let error = result.unwrap_err(); + assert!( + error + .to_string() + .contains("Potentially malicious input detected") + ); } } From 8cae567268c3978d1d61a545e7d0ce2830b16688 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 07:51:33 -0500 Subject: [PATCH 10/15] add bool support --- server/src/routes/search.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 076fd764..18c9ab4e 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -152,6 +152,7 @@ pub async fn search( query::FilterValueType::String(s) => prepared_query.bind(s), query::FilterValueType::Integer(i) => prepared_query.bind(i), query::FilterValueType::Float(f) => prepared_query.bind(f), + query::FilterValueType::Boolean(b) => prepared_query.bind(b), }; } From b3bdd85cd6d281a34a3b00e2dc5790b10f0e4486 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 09:49:17 -0500 Subject: [PATCH 11/15] debug error --- 
server/tests/tests.rs | 4 ++-- server/tests/util.rs | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/server/tests/tests.rs b/server/tests/tests.rs index d522b0ad..c06f7212 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -107,7 +107,7 @@ async fn test_search_filters() { let test_num = rng.random_range(1..100000); let cfg = vectorize_core::config::Config::from_env(); let sql = std::fs::read_to_string("sql/example.sql").unwrap(); - common::exec_psql(&cfg.database_url, &sql); + common::exec_psql(&cfg.database_url, &sql).expect("failed to execute example.sql"); let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); // test table @@ -251,7 +251,7 @@ async fn test_search_filter_operators() { let test_num = rng.random_range(1..100000); let cfg = vectorize_core::config::Config::from_env(); let sql = std::fs::read_to_string("sql/example.sql").unwrap(); - common::exec_psql(&cfg.database_url, &sql); + common::exec_psql(&cfg.database_url, &sql).expect("failed to execute example.sql"); let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); // test table diff --git a/server/tests/util.rs b/server/tests/util.rs index 005e660d..4577acca 100644 --- a/server/tests/util.rs +++ b/server/tests/util.rs @@ -146,7 +146,7 @@ pub mod common { .expect("unable to update test data"); } - pub fn exec_psql(conn_string: &str, sql_content: &str) { + pub fn exec_psql(conn_string: &str, sql_content: &str) -> Result<(), String> { let output = Command::new("psql") .arg(conn_string) .arg("-c") @@ -158,10 +158,12 @@ pub mod common { "failed to execute SQL: {}", String::from_utf8_lossy(&output.stderr) ); - panic!( + Err(format!( "failed to execute SQL: {}", String::from_utf8_lossy(&output.stderr) - ); + )) + } else { + Ok(()) } } } From 035fdde40fa6bfb7979754d9f0ec67b7611a93fe Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 10:03:49 -0500 Subject: [PATCH 12/15] touch --- server/tests/tests.rs | 1 + 1 file changed, 1 
insertion(+) diff --git a/server/tests/tests.rs b/server/tests/tests.rs index c06f7212..1d44d081 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -250,6 +250,7 @@ async fn test_search_filter_operators() { let mut rng = rand::rng(); let test_num = rng.random_range(1..100000); let cfg = vectorize_core::config::Config::from_env(); + // install raw SQL let sql = std::fs::read_to_string("sql/example.sql").unwrap(); common::exec_psql(&cfg.database_url, &sql).expect("failed to execute example.sql"); From ca4428c04b92a41c1955229267122f0b71241e21 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 12:26:07 -0500 Subject: [PATCH 13/15] redundant parsing --- core/src/query.rs | 361 ++++------------------------------------------ 1 file changed, 26 insertions(+), 335 deletions(-) diff --git a/core/src/query.rs b/core/src/query.rs index fa06b5e9..a489244e 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -55,7 +55,8 @@ pub enum FilterValueType { } impl FilterValueType { - /// Get the value as a string for SQL binding + /// Get the value as a string for SQL binding (TEST-ONLY - use parameterized queries in production) + #[cfg(test)] pub fn as_sql_value(&self) -> String { match self { FilterValueType::String(s) => s.clone(), @@ -64,6 +65,17 @@ impl FilterValueType { FilterValueType::Boolean(b) => b.to_string(), } } + + /// Get the value for parameterized query binding + /// Returns the value as a type that can be used with sqlx query parameters + pub fn as_bind_value(&self) -> Box { + match self { + FilterValueType::String(s) => Box::new(s.clone()), + FilterValueType::Integer(i) => Box::new(*i), + FilterValueType::Float(f) => Box::new(*f), + FilterValueType::Boolean(b) => Box::new(*b), + } + } } /// Custom deserializer for FilterValue that parses operator.value format @@ -117,9 +129,7 @@ impl<'de> serde::Deserialize<'de> for FilterValue { } else if let Ok(float_val) = val.parse::() { FilterValueType::Float(float_val) } else { - // Validate 
string values for SQL injection - validate_filter_value(val) - .map_err(|e| de::Error::custom(e.to_string()))?; + // No validation needed with parameterized queries FilterValueType::String(val.to_string()) } } @@ -155,9 +165,7 @@ impl<'de> serde::Deserialize<'de> for FilterValue { } else if let Ok(float_val) = value.parse::() { FilterValueType::Float(float_val) } else { - // Validate string values for SQL injection - validate_filter_value(value) - .map_err(|e| de::Error::custom(e.to_string()))?; + // No validation needed with parameterized queries FilterValueType::String(value.to_string()) }; @@ -194,105 +202,6 @@ pub fn check_input(input: &str) -> Result<()> { } } -/// Validates filter values against SQL injection patterns -/// Returns an error if malicious patterns are detected -pub fn validate_filter_value(value: &str) -> Result<()> { - let value_lower = value.to_lowercase(); - - // Check for common SQL injection patterns - let malicious_patterns = [ - // Basic SQL injection patterns - "';", - "' or", - "' union", - "' and", - "' drop", - "' delete", - "' insert", - "' update", - "' create", - "' alter", - "' exec", - "' execute", - "' script", - "javascript:", - "") || value.contains("<")) { - return Err(anyhow!( - "Potentially malicious input detected: '{}' contains suspicious character combination", - value - )); - } - - // Check for excessive special characters that might indicate encoding attempts - let special_char_count = value - .chars() - .filter(|c| !c.is_alphanumeric() && *c != ' ' && *c != '.' 
&& *c != '-' && *c != '_') - .count(); - if special_char_count > value.len() / 4 { - return Err(anyhow!( - "Potentially malicious input detected: '{}' contains excessive special characters", - value - )); - } - - Ok(()) -} - pub fn create_vectorize_table() -> String { "CREATE TABLE IF NOT EXISTS vectorize.job ( @@ -716,7 +625,7 @@ pub fn join_table_cosine_similarity( join_key: &str, return_columns: &[String], num_results: i32, - where_clause: Option, + filters: &BTreeMap, ) -> String { let cols = &return_columns .iter() @@ -724,11 +633,15 @@ pub fn join_table_cosine_similarity( .collect::>() .join(","); - let where_str = if let Some(w) = where_clause { - prepare_filter(&w, join_key) - } else { - "".to_string() - }; + let mut bind_value_counter: i16 = 2; // Start at $2 since $1 is the vector + let mut where_filter = "WHERE 1=1".to_string(); + for (column, filter_value) in filters.iter() { + let operator = filter_value.operator.to_sql(); + let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}"); + where_filter.push_str(&filt); + bind_value_counter += 1; + } + let inner_query = format!( " SELECT @@ -748,7 +661,7 @@ pub fn join_table_cosine_similarity( {inner_query} ) t1 INNER JOIN {schema}.{table} t0 on t0.{join_key} = t1.{join_key} - {where_str} + {where_filter} ) t ORDER BY t.similarity_score DESC LIMIT {num_results}; @@ -756,12 +669,6 @@ pub fn join_table_cosine_similarity( ) } -// transform user's where_sql into the format search query expects -fn prepare_filter(filter: &str, pkey: &str) -> String { - let wc = filter.replace(pkey, &format!("t0.{pkey}")); - format!("AND {wc}") -} - #[allow(clippy::too_many_arguments)] pub fn hybrid_search_query( job_name: &str, @@ -1047,222 +954,6 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" assert_eq!(filter.value.as_sql_value(), "3.141592653589793"); } - // ===== SQL Injection Attack Tests ===== - - #[test] - fn test_filter_value_deserialize_sql_injection_basic() { - // Test basic SQL injection 
attempts - these should now be rejected - let malicious_inputs = vec![ - "'; DROP TABLE users; --", - "' OR '1'='1", - "' UNION SELECT * FROM users --", - "'; INSERT INTO users VALUES ('hacker', 'password'); --", - "' OR 1=1 --", - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_with_operators() { - // Test SQL injection with different operators - let malicious_inputs = vec![ - ("gt", "1; DROP TABLE users; --"), - ("gte", "0 OR 1=1 --"), - ( - "lt", - "999; INSERT INTO users VALUES ('hacker', 'password'); --", - ), - ("lte", "100 UNION SELECT * FROM users --"), - ]; - - for (op, malicious_input) in malicious_inputs { - let json = format!("\"{}.{}\"", op, malicious_input); - let result: Result = serde_json::from_str(&json); - // These should fail because comparison operators require numeric values - assert!( - result.is_err(), - "Should fail for non-numeric value with comparison operator: {}", - malicious_input - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_script_tags() { - // Test XSS-style attacks that might be used in SQL injection - these should now be rejected - let malicious_inputs = vec![ - "", - "javascript:alert('xss')", - "'; ; --", - "' OR ''='", - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn 
test_filter_value_deserialize_sql_injection_encoding_attempts() { - // Test various encoding attempts to bypass filters - these should now be rejected - let malicious_inputs = vec![ - "%27%20OR%201%3D1%20--", // URL encoded - "' OR 1=1 --", // HTML entity encoded - "' OR CHAR(49)=CHAR(49) --", // CHAR function - "' OR ASCII('a')=97 --", // ASCII function - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_time_based() { - // Test time-based SQL injection attempts - these should now be rejected - let malicious_inputs = vec![ - "'; WAITFOR DELAY '00:00:05' --", - "' OR SLEEP(5) --", - "'; SELECT pg_sleep(5); --", - "' OR (SELECT COUNT(*) FROM users WHERE username='admin' AND SLEEP(5))>0 --", - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_blind() { - // Test blind SQL injection attempts - these should now be rejected - let malicious_inputs = vec![ - "' AND (SELECT SUBSTRING(password,1,1) FROM users WHERE username='admin')='a' --", - "' OR (SELECT COUNT(*) FROM users WHERE username='admin' AND password LIKE 'a%')>0 --", - "' AND EXISTS(SELECT * FROM users WHERE username='admin' AND password LIKE 'a%') --", - "' OR (SELECT ASCII(SUBSTRING(password,1,1)) FROM users WHERE username='admin')=97 --", - ]; 
- - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_union() { - // Test UNION-based SQL injection attempts - these should now be rejected - let malicious_inputs = vec![ - "' UNION SELECT username, password FROM users --", - "' UNION ALL SELECT NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL --", - "' UNION SELECT 1,2,3,4,5,6,7,8,9,10 --", - "' UNION SELECT table_name FROM information_schema.tables --", - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - - #[test] - fn test_filter_value_deserialize_sql_injection_error_based() { - // Test error-based SQL injection attempts - these should now be rejected - let malicious_inputs = vec![ - "' AND (SELECT * FROM (SELECT COUNT(*),CONCAT(version(),FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a) --", - "' AND EXTRACTVALUE(1, CONCAT(0x7e, (SELECT version()), 0x7e)) --", - "' AND (SELECT * FROM (SELECT COUNT(*),CONCAT(CAST((SELECT version()) AS CHAR),0x7e,FLOOR(RAND(0)*2))x FROM information_schema.tables GROUP BY x)a) --", - ]; - - for malicious_input in malicious_inputs { - let json = format!("\"eq.{}\"", malicious_input); - let result: Result = serde_json::from_str(&json); - assert!( - result.is_err(), - "Should reject malicious input: {}", - malicious_input - ); - let error = 
result.unwrap_err(); - assert!( - error - .to_string() - .contains("Potentially malicious input detected") - ); - } - } - // ===== Error Handling Tests ===== #[test] From ec1f17fc8fc2f79eaf00e7c75fa597b840e4ea57 Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 16:44:26 -0500 Subject: [PATCH 14/15] handle concurrency --- server/tests/tests.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 1d44d081..f00008f8 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -107,7 +107,11 @@ async fn test_search_filters() { let test_num = rng.random_range(1..100000); let cfg = vectorize_core::config::Config::from_env(); let sql = std::fs::read_to_string("sql/example.sql").unwrap(); - common::exec_psql(&cfg.database_url, &sql).expect("failed to execute example.sql"); + if let Err(e) = common::exec_psql(&cfg.database_url, &sql) { + // installation of example.sql could fail due to race conditions + // so we can continue + log::warn!("failed to execute example.sql: {}", e); + } let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); // test table @@ -252,7 +256,11 @@ async fn test_search_filter_operators() { let cfg = vectorize_core::config::Config::from_env(); // install raw SQL let sql = std::fs::read_to_string("sql/example.sql").unwrap(); - common::exec_psql(&cfg.database_url, &sql).expect("failed to execute example.sql"); + if let Err(e) = common::exec_psql(&cfg.database_url, &sql) { + // installation of example.sql could fail due to race conditions + // so we can continue + log::warn!("failed to execute example.sql: {}", e); + } let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); // test table From 85abcbb74affa28313d2653385055ae20e4a7aba Mon Sep 17 00:00:00 2001 From: Adam Hendel Date: Thu, 9 Oct 2025 17:28:44 -0500 Subject: [PATCH 15/15] remove dead code --- core/src/transformers/providers/cohere.rs | 8 -------- 
core/src/transformers/providers/ollama.rs | 7 ------- server/tests/util.rs | 5 ++++- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/core/src/transformers/providers/cohere.rs b/core/src/transformers/providers/cohere.rs index 58094916..99d36be2 100644 --- a/core/src/transformers/providers/cohere.rs +++ b/core/src/transformers/providers/cohere.rs @@ -49,14 +49,6 @@ impl From for CohereEmbeddingBody { } } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct CohereEmbeddingResponse { - model: String, - texts: Vec, - input_type: String, - truncate: String, -} - impl CohereProvider { pub fn new(url: Option, api_key: Option) -> Result { let final_url = match url { diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index efc1360e..5ee8f775 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -8,7 +8,6 @@ use ollama_rs::{ generation::completion::request::GenerationRequest, generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, }; -use serde::{Deserialize, Serialize}; use url::Url; pub const OLLAMA_BASE_URL: &str = "http://localhost:3001"; @@ -17,12 +16,6 @@ pub struct OllamaProvider { pub instance: Ollama, } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct ModelInfo { - embedding_dimension: u32, - max_seq_len: u32, -} - impl OllamaProvider { pub fn new(url: Option) -> Self { let url_in = url.unwrap_or_else(|| OLLAMA_BASE_URL.to_string()); diff --git a/server/tests/util.rs b/server/tests/util.rs index 4577acca..be990c79 100644 --- a/server/tests/util.rs +++ b/server/tests/util.rs @@ -78,7 +78,10 @@ pub mod common { if start_time.elapsed() >= timeout_duration { Err(Box::new(std::io::Error::new( std::io::ErrorKind::TimedOut, - "Search request timed out after 10 seconds", + format!( + "Search request timed out after {} seconds", + timeout_duration.as_secs() + ), )))? }