diff --git a/core/src/query.rs b/core/src/query.rs index 3cc32afb..a489244e 100644 --- a/core/src/query.rs +++ b/core/src/query.rs @@ -10,54 +10,174 @@ use tiktoken_rs::cl100k_base; pub const VECTORIZE_SCHEMA: &str = "vectorize"; static TRIGGER_FN_PREFIX: &str = "vectorize.handle_update_"; -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum FilterValue { +/// Filter operators supported by the search API +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum FilterOperator { + /// Equal to (=) + Equal, + /// Greater than (>) + GreaterThan, + /// Greater than or equal (>=) + GreaterThanOrEqual, + /// Less than (<) + LessThan, + /// Less than or equal (<=) + LessThanOrEqual, +} + +impl FilterOperator { + /// Convert operator to SQL operator string + pub fn to_sql(&self) -> &'static str { + match self { + FilterOperator::Equal => "=", + FilterOperator::GreaterThan => ">", + FilterOperator::GreaterThanOrEqual => ">=", + FilterOperator::LessThan => "<", + FilterOperator::LessThanOrEqual => "<=", + } + } +} + +/// A filter value with an operator +#[derive(Debug, Clone, Serialize)] +pub struct FilterValue { + pub operator: FilterOperator, + pub value: FilterValueType, +} + +/// The actual value stored in a filter +#[derive(Debug, Clone, PartialEq, Serialize)] +pub enum FilterValueType { String(String), Integer(i64), Float(f64), Boolean(bool), } +impl FilterValueType { + /// Get the value as a string for SQL binding (TEST-ONLY - use parameterized queries in production) + #[cfg(test)] + pub fn as_sql_value(&self) -> String { + match self { + FilterValueType::String(s) => s.clone(), + FilterValueType::Integer(i) => i.to_string(), + FilterValueType::Float(f) => f.to_string(), + FilterValueType::Boolean(b) => b.to_string(), + } + } + + /// Get the value for parameterized query binding + /// Returns the value as a type that can be used with sqlx query parameters + pub fn as_bind_value(&self) -> Box { + match self { + FilterValueType::String(s) => 
Box::new(s.clone()), + FilterValueType::Integer(i) => Box::new(*i), + FilterValueType::Float(f) => Box::new(*f), + FilterValueType::Boolean(b) => Box::new(*b), + } + } +} + +/// Custom deserializer for FilterValue that parses operator.value format impl<'de> serde::Deserialize<'de> for FilterValue { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { - let s = String::deserialize(deserializer)?; + use serde::de::{self, Visitor}; + use std::fmt; - // Try to parse as boolean first - if let Ok(b) = s.parse::() { - return Ok(FilterValue::Boolean(b)); - } + struct FilterValueVisitor; - // Try to parse as integer - if let Ok(i) = s.parse::() { - return Ok(FilterValue::Integer(i)); - } + impl<'de> Visitor<'de> for FilterValueVisitor { + type Value = FilterValue; - // Try to parse as float - if let Ok(f) = s.parse::() { - return Ok(FilterValue::Float(f)); - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string in format 'operator.value' or just 'value'") + } - // Fall back to string + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + if let Some(dot_pos) = value.find('.') { + let operator_str = &value[..dot_pos]; + let val = &value[dot_pos + 1..]; - Ok(FilterValue::String(s)) - } -} + let operator = match operator_str { + "eq" => FilterOperator::Equal, + "gt" => FilterOperator::GreaterThan, + "gte" => FilterOperator::GreaterThanOrEqual, + "lt" => FilterOperator::LessThan, + "lte" => FilterOperator::LessThanOrEqual, + _ => { + return Err(de::Error::custom(format!( + "Unknown operator: {}", + operator_str + ))); + } + }; -impl FilterValue { - pub fn bind_to_query<'q>( - &'q self, - query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>, - ) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> { - match self { - FilterValue::String(s) => query.bind(s), - FilterValue::Integer(i) => query.bind(*i), - FilterValue::Float(f) => query.bind(*f), - 
FilterValue::Boolean(b) => query.bind(*b), + // Parse the value based on the operator + let parsed_value = match operator { + FilterOperator::Equal => { + // For equality, try to parse as boolean first, then number, fallback to string + if let Ok(bool_val) = val.parse::() { + FilterValueType::Boolean(bool_val) + } else if let Ok(int_val) = val.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = val.parse::() { + FilterValueType::Float(float_val) + } else { + // No validation needed with parameterized queries + FilterValueType::String(val.to_string()) + } + } + FilterOperator::GreaterThan + | FilterOperator::GreaterThanOrEqual + | FilterOperator::LessThan + | FilterOperator::LessThanOrEqual => { + // For comparison operators, require numeric values + if let Ok(int_val) = val.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = val.parse::() { + FilterValueType::Float(float_val) + } else { + return Err(de::Error::custom(format!( + "Comparison operators (gt, gte, lt, lte) require numeric values, got: '{}'", + val + ))); + } + } + }; + + Ok(FilterValue { + operator, + value: parsed_value, + }) + } else { + // Default to equality if no operator specified + // Try to parse as boolean first, then number, fallback to string + let parsed_value = if let Ok(bool_val) = value.parse::() { + FilterValueType::Boolean(bool_val) + } else if let Ok(int_val) = value.parse::() { + FilterValueType::Integer(int_val) + } else if let Ok(float_val) = value.parse::() { + FilterValueType::Float(float_val) + } else { + // No validation needed with parameterized queries + FilterValueType::String(value.to_string()) + }; + + Ok(FilterValue { + operator: FilterOperator::Equal, + value: parsed_value, + }) + } + } } + + deserializer.deserialize_str(FilterValueVisitor) } } @@ -505,7 +625,7 @@ pub fn join_table_cosine_similarity( join_key: &str, return_columns: &[String], num_results: i32, - where_clause: Option, + filters: &BTreeMap, ) -> String 
{ let cols = &return_columns .iter() @@ -513,11 +633,15 @@ pub fn join_table_cosine_similarity( .collect::>() .join(","); - let where_str = if let Some(w) = where_clause { - prepare_filter(&w, join_key) - } else { - "".to_string() - }; + let mut bind_value_counter: i16 = 2; // Start at $2 since $1 is the vector + let mut where_filter = "WHERE 1=1".to_string(); + for (column, filter_value) in filters.iter() { + let operator = filter_value.operator.to_sql(); + let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}"); + where_filter.push_str(&filt); + bind_value_counter += 1; + } + let inner_query = format!( " SELECT @@ -537,7 +661,7 @@ pub fn join_table_cosine_similarity( {inner_query} ) t1 INNER JOIN {schema}.{table} t0 on t0.{join_key} = t1.{join_key} - {where_str} + {where_filter} ) t ORDER BY t.similarity_score DESC LIMIT {num_results}; @@ -545,12 +669,6 @@ pub fn join_table_cosine_similarity( ) } -// transform user's where_sql into the format search query expects -fn prepare_filter(filter: &str, pkey: &str) -> String { - let wc = filter.replace(pkey, &format!("t0.{pkey}")); - format!("AND {wc}") -} - #[allow(clippy::too_many_arguments)] pub fn hybrid_search_query( job_name: &str, @@ -573,8 +691,9 @@ pub fn hybrid_search_query( let mut bind_value_counter: i16 = 3; let mut where_filter = "WHERE 1=1".to_string(); - for column in filters.keys() { - let filt = format!(" AND t0.\"{column}\" = ${bind_value_counter}"); + for (column, filter_value) in filters.iter() { + let operator = filter_value.operator.to_sql(); + let filt = format!(" AND t0.\"{column}\" {operator} ${bind_value_counter}"); where_filter.push_str(&filt); bind_value_counter += 1; } @@ -643,6 +762,7 @@ pub fn hybrid_search_query( #[cfg(test)] mod tests { use super::*; + use serde_json; #[test] fn test_create_update_trigger_single() { @@ -675,4 +795,615 @@ EXECUTE FUNCTION vectorize.handle_update_another_job();" let result = create_event_trigger(job_name, "myschema", table_name, 
"INSERT"); assert_eq!(expected, result); } + + // ===== FilterValue Deserialization Tests ===== + + #[test] + fn test_filter_value_deserialize_equality_string() { + let json = "\"eq.hello\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello"); + } + + #[test] + fn test_filter_value_deserialize_equality_integer() { + let json = "\"eq.42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_equality_float() { + let json = "\"eq.3.14\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "3.14"); + } + + #[test] + fn test_filter_value_deserialize_greater_than() { + let json = "\"gt.100\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::GreaterThan); + assert_eq!(filter.value.as_sql_value(), "100"); + } + + #[test] + fn test_filter_value_deserialize_greater_than_or_equal() { + let json = "\"gte.50.5\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::GreaterThanOrEqual); + assert_eq!(filter.value.as_sql_value(), "50.5"); + } + + #[test] + fn test_filter_value_deserialize_less_than() { + let json = "\"lt.25\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::LessThan); + assert_eq!(filter.value.as_sql_value(), "25"); + } + + #[test] + fn test_filter_value_deserialize_less_than_or_equal() { + let json = "\"lte.10.0\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::LessThanOrEqual); + assert_eq!(filter.value.as_sql_value(), "10"); 
+ } + + #[test] + fn test_filter_value_deserialize_default_equality() { + let json = "\"hello\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello"); + } + + #[test] + fn test_filter_value_deserialize_default_equality_numeric() { + let json = "\"42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + // ===== Edge Case Tests ===== + + #[test] + fn test_filter_value_deserialize_empty_string() { + let json = "\"eq.\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), ""); + } + + #[test] + fn test_filter_value_deserialize_zero_values() { + let json_int = "\"eq.0\""; + let filter_int: FilterValue = serde_json::from_str(json_int).unwrap(); + assert_eq!(filter_int.operator, FilterOperator::Equal); + assert_eq!(filter_int.value.as_sql_value(), "0"); + + let json_float = "\"eq.0.0\""; + let filter_float: FilterValue = serde_json::from_str(json_float).unwrap(); + assert_eq!(filter_float.operator, FilterOperator::Equal); + assert_eq!(filter_float.value.as_sql_value(), "0"); + } + + #[test] + fn test_filter_value_deserialize_negative_values() { + let json_int = "\"eq.-42\""; + let filter_int: FilterValue = serde_json::from_str(json_int).unwrap(); + assert_eq!(filter_int.operator, FilterOperator::Equal); + assert_eq!(filter_int.value.as_sql_value(), "-42"); + + let json_float = "\"eq.-3.14\""; + let filter_float: FilterValue = serde_json::from_str(json_float).unwrap(); + assert_eq!(filter_float.operator, FilterOperator::Equal); + assert_eq!(filter_float.value.as_sql_value(), "-3.14"); + } + + #[test] + fn test_filter_value_deserialize_special_characters() { + let json = "\"eq.hello-world_123\""; + let filter: FilterValue = 
serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "hello-world_123"); + } + + #[test] + fn test_filter_value_deserialize_unicode_characters() { + let json = "\"eq.测试\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "测试"); + } + + #[test] + fn test_filter_value_deserialize_whitespace_values() { + let json = "\"eq. hello \""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), " hello "); + } + + #[test] + fn test_filter_value_deserialize_scientific_notation() { + let json = "\"eq.1e5\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "100000"); + } + + #[test] + fn test_filter_value_deserialize_large_numbers() { + let json = "\"eq.9223372036854775807\""; // i64::MAX + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "9223372036854775807"); + } + + #[test] + fn test_filter_value_deserialize_precision_float() { + let json = "\"eq.3.141592653589793\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "3.141592653589793"); + } + + // ===== Error Handling Tests ===== + + #[test] + fn test_filter_value_deserialize_invalid_operator() { + let json = "\"invalid.42\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err(), "Should fail for invalid operator"); + let error = result.unwrap_err(); + assert!(error.to_string().contains("Unknown operator")); + } + + #[test] + fn 
test_filter_value_deserialize_comparison_with_string() { + // Test that comparison operators fail with non-numeric values + let test_cases = vec![ + ("gt", "hello"), + ("gte", "world"), + ("lt", "test"), + ("lte", "string"), + ]; + + for (op, value) in test_cases { + let json = format!("\"{}.{}\"", op, value); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for non-numeric value with {} operator", + op + ); + let error = result.unwrap_err(); + assert!(error.to_string().contains("require numeric values")); + } + } + + #[test] + fn test_filter_value_deserialize_malformed_json() { + // Test various malformed JSON inputs + let malformed_inputs = vec![ + ("\"eq.42", false), // Missing closing quote + ("eq.42\"", false), // Missing opening quote + ("\"eq.42\"", true), // This should work + ("\"eq.42.extra\"", true), // Extra dot should work as string + ("\"eq.\"", true), // Empty value should work + ("\".42\"", false), // Missing operator should fail + ("\"eq\"", true), // Missing dot and value should work as string + ]; + + for (input, should_succeed) in malformed_inputs { + let result: Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + } else { + assert!( + result.is_err(), + "Should fail for malformed input: {}", + input + ); + } + } + } + + #[test] + fn test_filter_value_deserialize_empty_input() { + let json = "\"\""; + let result: Result = serde_json::from_str(json); + // Empty input should succeed and default to equality with empty string + assert!(result.is_ok(), "Empty input should succeed"); + let filter = result.unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), ""); + } + + #[test] + fn test_filter_value_deserialize_just_dot() { + let json = "\".\""; + let result: Result = serde_json::from_str(json); + assert!(result.is_err(), "Should fail for input with just a dot"); + } + + 
#[test] + fn test_filter_value_deserialize_multiple_dots() { + let json = "\"eq.42.extra\""; + let result: Result = serde_json::from_str(json); + // Multiple dots should succeed and treat the whole thing as a string + assert!( + result.is_ok(), + "Should succeed for input with multiple dots" + ); + let filter = result.unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value.as_sql_value(), "42.extra"); + } + + #[test] + fn test_filter_value_deserialize_case_sensitive_operators() { + // Test that operators are case sensitive + let case_variations = vec!["EQ.42", "GT.42", "GTE.42", "LT.42", "LTE.42"]; + + for input in case_variations { + let json = format!("\"{}\"", input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for uppercase operator: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_whitespace_in_operator() { + // Test operators with whitespace - these actually succeed as they're treated as strings + let whitespace_inputs = vec![ + ("\" eq.42\"", false), // Leading space - fails + ("\"eq .42\"", false), // Space before dot - fails + ("\"eq. 
42\"", true), // Space after dot - succeeds as string + ("\"eq.42 \"", true), // Trailing space - succeeds as string + ]; + + for (input, should_succeed) in whitespace_inputs { + let result: Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + } else { + assert!(result.is_err(), "Should fail for input: {}", input); + } + } + } + + // ===== Numeric Parsing Edge Cases ===== + + #[test] + fn test_filter_value_deserialize_numeric_boundaries() { + // Test integer boundaries + let json_max_i64 = "\"eq.9223372036854775807\""; // i64::MAX + let filter_max: FilterValue = serde_json::from_str(json_max_i64).unwrap(); + assert_eq!(filter_max.value.as_sql_value(), "9223372036854775807"); + + let json_min_i64 = "\"eq.-9223372036854775808\""; // i64::MIN + let filter_min: FilterValue = serde_json::from_str(json_min_i64).unwrap(); + assert_eq!(filter_min.value.as_sql_value(), "-9223372036854775808"); + } + + #[test] + fn test_filter_value_deserialize_float_precision() { + // Test various float precision cases + let test_cases = vec![ + ("0.0", "0"), + ("0.1", "0.1"), + ("0.01", "0.01"), + ("0.001", "0.001"), + ("1.0", "1"), + ("1.1", "1.1"), + ("1.11", "1.11"), + ("1.111", "1.111"), + ]; + + for (input, expected) in test_cases { + let json = format!("\"eq.{}\"", input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!( + filter.value.as_sql_value(), + expected, + "Failed for input: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_scientific_notation_edge_cases() { + // Test scientific notation edge cases + let test_cases = vec![ + ("1e0", "1"), + ("1e1", "10"), + ("1e-1", "0.1"), + ("1e-10", "0.0000000001"), + ("1e10", "10000000000"), + ("1.5e2", "150"), + ("1.5e-2", "0.015"), + ]; + + for (input, expected) in test_cases { + let json = format!("\"eq.{}\"", input); + let filter: FilterValue = serde_json::from_str(&json).unwrap(); + assert_eq!( + 
filter.value.as_sql_value(), + expected, + "Failed for input: {}", + input + ); + } + } + + #[test] + fn test_filter_value_deserialize_hex_numbers() { + // Test that hex numbers are treated as strings (not parsed as integers) + let json = "\"eq.0xFF\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "0xFF"); + } + + #[test] + fn test_filter_value_deserialize_octal_numbers() { + // Test that an octal-looking value (leading zero) is parsed as the decimal integer 777, not as octal + let json = "\"eq.0777\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "777"); + } + + #[test] + fn test_filter_value_deserialize_binary_numbers() { + // Test that binary numbers are treated as strings + let json = "\"eq.0b1010\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "0b1010"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_leading_zeros() { + // Test numbers with leading zeros + let json = "\"eq.007\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "7"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_trailing_zeros() { + // Test numbers with trailing zeros + let json = "\"eq.42.000\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_plus_sign() { + // Test numbers with explicit plus sign + let json = "\"+42\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "42"); + } + + #[test] + fn test_filter_value_deserialize_numeric_with_plus_sign_float() { + // Test floats with explicit plus sign and no operator (rejected: the text before the first dot is taken as the operator) + let json = "\"+3.14\""; + let result: Result = serde_json::from_str(json); + // This should fail because "+3" is not a valid operator + assert!( +
result.is_err(), + "Should fail for input with plus sign as operator" + ); + } + + #[test] + fn test_filter_value_deserialize_numeric_infinity() { + // Test infinity values (should be parsed as float infinity) + let json = "\"eq.infinity\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.value.as_sql_value(), "inf"); + } + + #[test] + fn test_filter_value_deserialize_numeric_nan() { + // Test NaN values (should be parsed as float NaN) + let json = "\"eq.nan\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // NaN comparison requires special handling + match filter.value { + FilterValueType::Float(f) => assert!(f.is_nan(), "Expected NaN"), + _ => panic!("Expected Float(NaN)"), + } + assert_eq!(filter.value.as_sql_value(), "NaN"); + } + + #[test] + fn test_filter_value_deserialize_numeric_very_small() { + // Test very small numbers + let json = "\"eq.0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // Very small numbers get parsed as floats and converted to "0" when using to_string() + assert_eq!(filter.value.as_sql_value(), "0"); + } + + #[test] + fn test_filter_value_deserialize_numeric_very_large() { + // Test very large numbers (using f64::MAX as a reasonable upper bound) + let json = "\"eq.1.7976931348623157e308\""; // f64::MAX + let filter: FilterValue = serde_json::from_str(json).unwrap(); + // Very large numbers get parsed as floats and converted to a long decimal string when using to_string() + assert_eq!( + filter.value.as_sql_value(), + 
"179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" + ); + } + + // ===== Boolean Filter Value Tests ===== + + #[test] + fn test_filter_value_deserialize_boolean_true() { + let json = "\"eq.true\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(true)); + assert_eq!(filter.value.as_sql_value(), "true"); + } + + #[test] + fn test_filter_value_deserialize_boolean_false() { + let json = "\"eq.false\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(false)); + assert_eq!(filter.value.as_sql_value(), "false"); + } + + #[test] + fn test_filter_value_deserialize_boolean_default_true() { + let json = "\"true\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(true)); + assert_eq!(filter.value.as_sql_value(), "true"); + } + + #[test] + fn test_filter_value_deserialize_boolean_default_false() { + let json = "\"false\""; + let filter: FilterValue = serde_json::from_str(json).unwrap(); + assert_eq!(filter.operator, FilterOperator::Equal); + assert_eq!(filter.value, FilterValueType::Boolean(false)); + assert_eq!(filter.value.as_sql_value(), "false"); + } + + #[test] + fn test_filter_value_deserialize_boolean_case_sensitive() { + // Test that boolean parsing is case sensitive - uppercase values become strings + let test_cases = vec![ + ("\"eq.True\"", FilterValueType::String("True".to_string())), + ("\"eq.False\"", 
FilterValueType::String("False".to_string())), + ("\"eq.TRUE\"", FilterValueType::String("TRUE".to_string())), + ("\"eq.FALSE\"", FilterValueType::String("FALSE".to_string())), + ("\"eq.true\"", FilterValueType::Boolean(true)), + ("\"eq.false\"", FilterValueType::Boolean(false)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_with_whitespace() { + // Test boolean parsing with whitespace + let test_cases = vec![ + ("\"eq. true\"", true), // Space before true - should succeed as string + ("\"eq.false \"", true), // Space after false - should succeed as string + ("\"eq. true \"", true), // Spaces around true - should succeed as string + ]; + + for (input, should_succeed) in test_cases { + let result: Result = serde_json::from_str(input); + if should_succeed { + assert!(result.is_ok(), "Should succeed for input: {}", input); + let filter = result.unwrap(); + // With whitespace, it should be parsed as a string, not boolean + assert!(matches!(filter.value, FilterValueType::String(_))); + } else { + assert!(result.is_err(), "Should fail for input: {}", input); + } + } + } + + #[test] + fn test_filter_value_deserialize_boolean_vs_string() { + // Test that "true" and "false" strings are not parsed as booleans + let test_cases = vec![ + ( + "\"eq.true_string\"", + FilterValueType::String("true_string".to_string()), + ), + ( + "\"eq.false_string\"", + FilterValueType::String("false_string".to_string()), + ), + ( + "\"eq.true123\"", + FilterValueType::String("true123".to_string()), + ), + ( + "\"eq.false456\"", + FilterValueType::String("false456".to_string()), + ), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_vs_numeric() 
{ + // Test that numeric values are still parsed as numbers, not booleans + let test_cases = vec![ + ("\"eq.1\"", FilterValueType::Integer(1)), + ("\"eq.0\"", FilterValueType::Integer(0)), + ("\"eq.1.0\"", FilterValueType::Float(1.0)), + ("\"eq.0.0\"", FilterValueType::Float(0.0)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_comparison_operators() { + // Test that comparison operators with boolean values fail (as they should require numeric values) + let test_cases = vec!["gt.true", "gte.false", "lt.true", "lte.false"]; + + for input in test_cases { + let json = format!("\"{}\"", input); + let result: Result = serde_json::from_str(&json); + assert!( + result.is_err(), + "Should fail for boolean value with comparison operator: {}", + input + ); + let error = result.unwrap_err(); + assert!(error.to_string().contains("require numeric values")); + } + } + + #[test] + fn test_filter_value_deserialize_boolean_edge_cases() { + // Test edge cases for boolean parsing + let test_cases = vec![ + ("\"eq.true\"", FilterValueType::Boolean(true)), + ("\"eq.false\"", FilterValueType::Boolean(false)), + ("\"true\"", FilterValueType::Boolean(true)), + ("\"false\"", FilterValueType::Boolean(false)), + ]; + + for (input, expected_value) in test_cases { + let filter: FilterValue = serde_json::from_str(input).unwrap(); + assert_eq!(filter.value, expected_value); + assert_eq!(filter.operator, FilterOperator::Equal); + } + } } diff --git a/core/src/transformers/providers/cohere.rs b/core/src/transformers/providers/cohere.rs index 58094916..99d36be2 100644 --- a/core/src/transformers/providers/cohere.rs +++ b/core/src/transformers/providers/cohere.rs @@ -49,14 +49,6 @@ impl From for CohereEmbeddingBody { } } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct CohereEmbeddingResponse { - model: String, - 
texts: Vec, - input_type: String, - truncate: String, -} - impl CohereProvider { pub fn new(url: Option, api_key: Option) -> Result { let final_url = match url { diff --git a/core/src/transformers/providers/ollama.rs b/core/src/transformers/providers/ollama.rs index efc1360e..5ee8f775 100644 --- a/core/src/transformers/providers/ollama.rs +++ b/core/src/transformers/providers/ollama.rs @@ -8,7 +8,6 @@ use ollama_rs::{ generation::completion::request::GenerationRequest, generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}, }; -use serde::{Deserialize, Serialize}; use url::Url; pub const OLLAMA_BASE_URL: &str = "http://localhost:3001"; @@ -17,12 +16,6 @@ pub struct OllamaProvider { pub instance: Ollama, } -#[derive(Clone, Debug, Serialize, Deserialize)] -struct ModelInfo { - embedding_dimension: u32, - max_seq_len: u32, -} - impl OllamaProvider { pub fn new(url: Option) -> Self { let url_in = url.unwrap_or_else(|| OLLAMA_BASE_URL.to_string()); diff --git a/server/src/routes/search.rs b/server/src/routes/search.rs index 5bdd3e10..18c9ab4e 100644 --- a/server/src/routes/search.rs +++ b/server/src/routes/search.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use utoipa::ToSchema; use uuid::Uuid; -use vectorize_core::query; +use vectorize_core::query::{self, FilterValue}; use vectorize_core::transformers::providers::prepare_generic_embedding_request; use vectorize_core::transformers::types::Inputs; use vectorize_core::types::VectorizeJob; @@ -28,7 +28,7 @@ pub struct SearchRequest { #[serde(default = "default_fts_wt")] pub fts_wt: f32, #[serde(flatten, default)] - pub filters: BTreeMap, + pub filters: BTreeMap, } fn default_semantic_wt() -> f32 { @@ -86,13 +86,10 @@ pub async fn search( // check inputs and filters are valid if they exist and create a SQL string for them query::check_input(&payload.job_name)?; if !payload.filters.is_empty() { - for (key, value) in &payload.filters { - // validate key and value + for key in 
payload.filters.keys() { + // validate key only (column names should be alphanumeric + underscore) query::check_input(key)?; - if let query::FilterValue::String(value) = value { - // only need to check the value if it is a raw string - query::check_input(value)?; - } + // Note: filter values are validated during deserialization in FilterValue } } @@ -149,9 +146,14 @@ pub async fn search( .bind(&embeddings.embeddings[0]) .bind(&payload.query); - // Bind filter values using the same BTreeMap instance used by the query builder + // Bind filter values for value in payload.filters.values() { - prepared_query = value.bind_to_query(prepared_query); + prepared_query = match &value.value { + query::FilterValueType::String(s) => prepared_query.bind(s), + query::FilterValueType::Integer(i) => prepared_query.bind(i), + query::FilterValueType::Float(f) => prepared_query.bind(f), + query::FilterValueType::Boolean(b) => prepared_query.bind(b), + }; } let results = prepared_query.fetch_all(&**pool).await?; diff --git a/server/tests/tests.rs b/server/tests/tests.rs index 3c92a211..f00008f8 100644 --- a/server/tests/tests.rs +++ b/server/tests/tests.rs @@ -107,7 +107,11 @@ async fn test_search_filters() { let test_num = rng.random_range(1..100000); let cfg = vectorize_core::config::Config::from_env(); let sql = std::fs::read_to_string("sql/example.sql").unwrap(); - common::exec_psql(&cfg.database_url, &sql); + if let Err(e) = common::exec_psql(&cfg.database_url, &sql) { + // installation of example.sql could fail due to race conditions + // so we can continue + log::warn!("failed to execute example.sql: {}", e); + } let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); // test table @@ -149,8 +153,8 @@ async fn test_search_filters() { resp.status() ); - // filter a query by product_category - let params = format!("job_name={job_name}&query=pen&product_category=electronics",); + // filter a query by product_category (using eq operator) + let params = 
format!("job_name={job_name}&query=pen&product_category=eq.electronics",); let search_results = common::search_with_retry(¶ms, 9).await.unwrap(); assert_eq!(search_results.len(), 9); @@ -159,8 +163,7 @@ async fn test_search_filters() { assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); } - // filter by price - let params = format!("job_name={job_name}&query=electronics&price=25"); + let params = format!("job_name={job_name}&query=electronics&price=eq.25"); let search_results = common::search_with_retry(¶ms, 2).await.unwrap(); assert_eq!(search_results.len(), 2); assert_eq!( @@ -171,6 +174,171 @@ async fn test_search_filters() { search_results[1]["product_name"].as_str().unwrap(), "Alarm Clock" ); + + // test greater than or equal operator + let params = format!("job_name={job_name}&query=electronics&price=gte.25&limit=5"); + let search_results = common::search_with_retry(¶ms, 5).await.unwrap(); + assert_eq!(search_results.len(), 5); + + // test backward compatibility - no operator should default to equality + let params = format!("job_name={job_name}&query=pen&product_category=electronics",); + let search_results = common::search_with_retry(¶ms, 9).await.unwrap(); + assert_eq!(search_results.len(), 9); + for result in search_results { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + } + + // test multiple filters - category first, then price + let params = format!( + "job_name={job_name}&query=electronics&product_category=eq.electronics&price=gte.25" + ); + let search_results_category_first = common::search_with_retry(¶ms, 5).await.unwrap(); + assert_eq!(search_results_category_first.len(), 5); + for result in &search_results_category_first { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + assert!(result["price"].as_f64().unwrap() >= 25.0); + } + + // test multiple filters - price first, then category (different order) + let params = format!( + 
"job_name={job_name}&query=electronics&price=gte.25&product_category=eq.electronics" + ); + let search_results_price_first = common::search_with_retry(¶ms, 5).await.unwrap(); + assert_eq!(search_results_price_first.len(), 5); + for result in &search_results_price_first { + assert_eq!(result["product_category"].as_str().unwrap(), "electronics"); + assert!(result["price"].as_f64().unwrap() >= 25.0); + } + + // verify that both filter orders produce the same results + assert_eq!( + search_results_category_first.len(), + search_results_price_first.len() + ); + // Sort both results by product_id to ensure consistent comparison + let mut category_first_sorted = search_results_category_first.clone(); + let mut price_first_sorted = search_results_price_first.clone(); + category_first_sorted.sort_by(|a, b| { + a["product_id"] + .as_i64() + .unwrap() + .cmp(&b["product_id"].as_i64().unwrap()) + }); + price_first_sorted.sort_by(|a, b| { + a["product_id"] + .as_i64() + .unwrap() + .cmp(&b["product_id"].as_i64().unwrap()) + }); + + for (i, (result1, result2)) in category_first_sorted + .iter() + .zip(price_first_sorted.iter()) + .enumerate() + { + assert_eq!( + result1["product_id"], result2["product_id"], + "Product IDs should match at index {}", + i + ); + assert_eq!( + result1["product_name"], result2["product_name"], + "Product names should match at index {}", + i + ); + } +} + +#[tokio::test] +async fn test_search_filter_operators() { + let mut rng = rand::rng(); + let test_num = rng.random_range(1..100000); + let cfg = vectorize_core::config::Config::from_env(); + // install raw SQL + let sql = std::fs::read_to_string("sql/example.sql").unwrap(); + if let Err(e) = common::exec_psql(&cfg.database_url, &sql) { + // installation of example.sql could fail due to race conditions + // so we can continue + log::warn!("failed to execute example.sql: {}", e); + } + + let pool = sqlx::PgPool::connect(&cfg.database_url).await.unwrap(); + // test table + let table = 
format!("test_filter_ops_{test_num}"); + let drop_sql = format!("DROP TABLE IF EXISTS public.{table};"); + let create_sql = + format!("CREATE TABLE public.{table} (LIKE public.my_products INCLUDING ALL);"); + let insert_sql = format!("INSERT INTO public.{table} SELECT * FROM public.my_products;"); + + sqlx::query(&drop_sql).execute(&pool).await.unwrap(); + sqlx::query(&create_sql).execute(&pool).await.unwrap(); + sqlx::query(&insert_sql).execute(&pool).await.unwrap(); + + // initialize search job + let job_name = format!("test_filter_ops_{test_num}"); + let payload = json!({ + "job_name": job_name, + "src_table": table, + "src_schema": "public", + "src_columns": ["description"], + "primary_key": "product_id", + "update_time_col": "updated_at", + "model": "sentence-transformers/all-MiniLM-L6-v2" + }); + + let client = reqwest::Client::new(); + let resp = client + .post("http://localhost:8080/api/v1/table") + .header("Content-Type", "application/json") + .json(&payload) + .send() + .await + .expect("Failed to send request"); + + assert_eq!( + resp.status(), + reqwest::StatusCode::OK, + "Response status: {:?}", + resp.status() + ); + + // Test different operators + // Greater than + let params = format!("job_name={job_name}&query=electronics&price=gt.20&limit=100"); + let search_results = common::search_with_retry(¶ms, 14).await.unwrap(); + assert_eq!(search_results.len(), 14); + + // Less than or equal + let params = format!("job_name={job_name}&query=electronics&price=lte.25&limit=100"); + let search_results = common::search_with_retry(¶ms, 30).await.unwrap(); + assert_eq!(search_results.len(), 30); + + // Test float values + let params = format!("job_name={job_name}&query=electronics&price=gte.24.5&limit=1000"); + let search_results = common::search_with_retry(¶ms, 12).await.unwrap(); + assert_eq!(search_results.len(), 12); + + // Test invalid operator (should return error) + let params = format!("job_name={job_name}&query=electronics&price=invalid.25"); + let 
response = client + .get(&format!("http://localhost:8080/api/v1/search?{}", params)) + .send() + .await + .expect("Failed to send request"); + + // Should return an error for invalid operator + assert!(response.status().is_client_error() || response.status().is_server_error()); + + // Test non-numeric value with comparison operator (should return error) + let params = format!("job_name={job_name}&query=electronics&price=gt.abc"); + let response = client + .get(&format!("http://localhost:8080/api/v1/search?{}", params)) + .send() + .await + .expect("Failed to send request"); + + // Should return an error for non-numeric value with comparison operator + assert!(response.status().is_client_error() || response.status().is_server_error()); } /// proxy is an incomplete feature diff --git a/server/tests/util.rs b/server/tests/util.rs index d2f6cf69..be990c79 100644 --- a/server/tests/util.rs +++ b/server/tests/util.rs @@ -76,7 +76,13 @@ pub mod common { // Check if we've exceeded the timeout if start_time.elapsed() >= timeout_duration { - panic!("Search request timed out after 10 seconds"); + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "Search request timed out after {} seconds", + timeout_duration.as_secs() + ), + )))? } // Wait before retrying @@ -143,7 +149,7 @@ pub mod common { .expect("unable to update test data"); } - pub fn exec_psql(conn_string: &str, sql_content: &str) { + pub fn exec_psql(conn_string: &str, sql_content: &str) -> Result<(), String> { let output = Command::new("psql") .arg(conn_string) .arg("-c") @@ -155,10 +161,12 @@ pub mod common { "failed to execute SQL: {}", String::from_utf8_lossy(&output.stderr) ); - panic!( + Err(format!( "failed to execute SQL: {}", String::from_utf8_lossy(&output.stderr) - ); + )) + } else { + Ok(()) } } }