From eb5aa475db7b7d5b6074720be858b00059071545 Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:12:29 -0700 Subject: [PATCH 1/7] add unit tests for litewire-tds crate Add 51 tests across all three source files covering PacketType::from_u8, sqlite_to_tds_type, value_to_tds_type, build_columns, decode_utf16le, TdsSession transaction state management, and skip_all_headers. --- crates/litewire-tds/src/handler.rs | 167 ++++++++++++++++++++++ crates/litewire-tds/src/packet.rs | 57 ++++++++ crates/litewire-tds/src/token.rs | 214 +++++++++++++++++++++++++++++ 3 files changed, 438 insertions(+) diff --git a/crates/litewire-tds/src/handler.rs b/crates/litewire-tds/src/handler.rs index 3e6b36f..76fdcfc 100644 --- a/crates/litewire-tds/src/handler.rs +++ b/crates/litewire-tds/src/handler.rs @@ -512,3 +512,170 @@ async fn send_error( token::write_done(&mut resp, token::DONE_FINAL, 0); packet::write_message(stream, PacketType::Response, &resp, DEFAULT_PACKET_SIZE).await } + +#[cfg(test)] +mod tests { + use super::*; + + // ── decode_utf16le ───────────────────────────────────────────────────── + + #[test] + fn decode_utf16le_ascii() { + // "Hello" as UTF-16LE + let data: Vec = "Hello" + .encode_utf16() + .flat_map(|c| c.to_le_bytes()) + .collect(); + assert_eq!(decode_utf16le(&data), Some("Hello".to_string())); + } + + #[test] + fn decode_utf16le_unicode() { + // "café" contains a non-ASCII character + let data: Vec = "caf\u{00E9}" + .encode_utf16() + .flat_map(|c| c.to_le_bytes()) + .collect(); + assert_eq!(decode_utf16le(&data), Some("caf\u{00E9}".to_string())); + } + + #[test] + fn decode_utf16le_emoji() { + // Emoji requires a surrogate pair in UTF-16 + let data: Vec = "\u{1F600}" + .encode_utf16() + .flat_map(|c| c.to_le_bytes()) + .collect(); + assert_eq!(decode_utf16le(&data), Some("\u{1F600}".to_string())); + } + + #[test] + fn decode_utf16le_empty() { + assert_eq!(decode_utf16le(&[]), Some(String::new())); + } + + #[test] + fn decode_utf16le_odd_length_returns_none() { + // Odd number of bytes cannot be valid UTF-16LE + assert_eq!(decode_utf16le(&[0x41]), None); + assert_eq!(decode_utf16le(&[0x41, 0x00, 0x42]), None); + } + + // ── TdsSession ───────────────────────────────────────────────────────── + + #[test] + fn session_new_initial_state() { + let session = TdsSession::new(); + assert!(!session.in_transaction); + assert_eq!(session.current_tran_id, 0); + assert_eq!(session.next_tran_id, 1); + } + + #[test] + fn session_begin_sets_transaction() { + let mut session = TdsSession::new(); + let id = session.begin(); + assert_eq!(id, 1); + assert!(session.in_transaction); + assert_eq!(session.current_tran_id, 1); + } + + #[test] + fn session_end_clears_transaction() { + let mut session = TdsSession::new(); + session.begin(); + let old_id = session.end(); + assert_eq!(old_id, 1); + assert!(!session.in_transaction); + assert_eq!(session.current_tran_id, 0); + } + + #[test] + fn session_multiple_begin_end_cycles() { + let mut session = TdsSession::new(); + + let id1 = session.begin(); + assert_eq!(id1, 1); + let old1 = session.end(); + assert_eq!(old1, 1); + + let id2 = session.begin(); + assert_eq!(id2, 2); + assert!(session.in_transaction); + assert_eq!(session.current_tran_id, 2); + + let old2 = session.end(); + assert_eq!(old2, 2); + assert!(!session.in_transaction); + + let id3 = session.begin(); + assert_eq!(id3, 3); + } + + #[test] + fn session_next_tran_id_increments() { + let mut session = TdsSession::new(); + session.begin(); + assert_eq!(session.next_tran_id, 2); + session.end(); + session.begin(); + assert_eq!(session.next_tran_id, 3); + } + + // ── skip_all_headers ─────────────────────────────────────────────────── + + #[test] + fn skip_all_headers_valid() { + // total_length = 10 (stored as u32 LE at the start), then 6 bytes of header data + let mut payload = vec![0u8; 20]; + payload[0..4].copy_from_slice(&10u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 10); + } + + #[test] + fn skip_all_headers_exact_payload_length() { + // total_length equals the entire payload length + let mut payload = vec![0u8; 8]; + payload[0..4].copy_from_slice(&8u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 8); + } + + #[test] + fn skip_all_headers_minimum_valid() { + // total_length = 4 (minimum valid: just the length field itself) + let mut payload = vec![0u8; 10]; + payload[0..4].copy_from_slice(&4u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 4); + } + + #[test] + fn skip_all_headers_payload_shorter_than_4_bytes() { + assert_eq!(skip_all_headers(&[]), 0); + assert_eq!(skip_all_headers(&[0x01]), 0); + assert_eq!(skip_all_headers(&[0x01, 0x02]), 0); + assert_eq!(skip_all_headers(&[0x01, 0x02, 0x03]), 0); + } + + #[test] + fn skip_all_headers_total_length_exceeds_payload() { + // total_length claims 100 bytes but payload is only 10 + let mut payload = vec![0u8; 10]; + payload[0..4].copy_from_slice(&100u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 0); + } + + #[test] + fn skip_all_headers_total_length_less_than_4() { + // total_length = 3 is invalid (< 4), should return 0 + let mut payload = vec![0u8; 10]; + payload[0..4].copy_from_slice(&3u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 0); + } + + #[test] + fn skip_all_headers_total_length_zero() { + let mut payload = vec![0u8; 10]; + payload[0..4].copy_from_slice(&0u32.to_le_bytes()); + assert_eq!(skip_all_headers(&payload), 0); + } +} diff --git a/crates/litewire-tds/src/packet.rs b/crates/litewire-tds/src/packet.rs index 8eac8db..6df8e09 100644 --- a/crates/litewire-tds/src/packet.rs +++ b/crates/litewire-tds/src/packet.rs @@ -142,3 +142,60 @@ pub async fn write_message( writer.flush().await } + +#[cfg(test)] +mod tests { + use super::*; + + // ── PacketType::from_u8 ──────────────────────────────────────────────── + + #[test] + fn packet_type_prelogin() { + assert_eq!(PacketType::from_u8(0x12), Some(PacketType::PreLogin)); + } + + #[test] + fn packet_type_login7() { + assert_eq!(PacketType::from_u8(0x10), Some(PacketType::Login7)); + } + + #[test] + fn packet_type_sql_batch() { + assert_eq!(PacketType::from_u8(0x01), Some(PacketType::SqlBatch)); + } + + #[test] + fn packet_type_rpc_request() { + assert_eq!(PacketType::from_u8(0x03), Some(PacketType::RpcRequest)); + } + + #[test] + fn packet_type_response() { + assert_eq!(PacketType::from_u8(0x04), Some(PacketType::Response)); + } + + #[test] + fn packet_type_invalid_zero() { + assert_eq!(PacketType::from_u8(0x00), None); + } + + #[test] + fn packet_type_invalid_0x02() { + assert_eq!(PacketType::from_u8(0x02), None); + } + + #[test] + fn packet_type_invalid_0xff() { + assert_eq!(PacketType::from_u8(0xFF), None); + } + + // ── PacketType repr round-trip ───────────────────────────────────────── + + #[test] + fn packet_type_round_trip() { + for byte in [0x01u8, 0x03, 0x04, 0x10, 0x12] { + let pt = PacketType::from_u8(byte).unwrap(); + assert_eq!(pt as u8, byte); + } + } +} diff --git a/crates/litewire-tds/src/token.rs b/crates/litewire-tds/src/token.rs index f3b5b73..4af42be 100644 --- a/crates/litewire-tds/src/token.rs +++ b/crates/litewire-tds/src/token.rs @@ -418,3 +418,217 @@ pub fn write_done(buf: &mut BytesMut, status: u16, row_count: u64) { buf.put_u16_le(0); // CurCmd buf.put_u64_le(row_count); // DoneRowCount (8 bytes in TDS 7.2+) } + +#[cfg(test)] +mod tests { + use super::*; + use litewire_backend::Column; + + // ── sqlite_to_tds_type ───────────────────────────────────────────────── + + #[test] + fn sqlite_type_none_defaults_to_nvarchar() { + assert!(matches!(sqlite_to_tds_type(None), TdsType::NVarChar)); + } + + #[test] + fn sqlite_type_integer() { + assert!(matches!(sqlite_to_tds_type(Some("INTEGER")), TdsType::BigInt)); + } + + #[test] + fn sqlite_type_int() { + assert!(matches!(sqlite_to_tds_type(Some("INT")), TdsType::BigInt)); + } + + #[test] + fn sqlite_type_bigint() { + assert!(matches!(sqlite_to_tds_type(Some("BIGINT")), TdsType::BigInt)); + } + + #[test] + fn sqlite_type_boolean() { + assert!(matches!(sqlite_to_tds_type(Some("BOOLEAN")), TdsType::BigInt)); + } + + #[test] + fn sqlite_type_bit() { + assert!(matches!(sqlite_to_tds_type(Some("BIT")), TdsType::BigInt)); + } + + #[test] + fn sqlite_type_real() { + assert!(matches!(sqlite_to_tds_type(Some("REAL")), TdsType::Float8)); + } + + #[test] + fn sqlite_type_float() { + assert!(matches!(sqlite_to_tds_type(Some("FLOAT")), TdsType::Float8)); + } + + #[test] + fn sqlite_type_double() { + assert!(matches!(sqlite_to_tds_type(Some("DOUBLE")), TdsType::Float8)); + } + + #[test] + fn sqlite_type_text() { + assert!(matches!(sqlite_to_tds_type(Some("TEXT")), TdsType::NVarChar)); + } + + #[test] + fn sqlite_type_varchar() { + assert!(matches!(sqlite_to_tds_type(Some("VARCHAR")), TdsType::NVarChar)); + } + + #[test] + fn sqlite_type_blob() { + assert!(matches!(sqlite_to_tds_type(Some("BLOB")), TdsType::VarBinary)); + } + + #[test] + fn sqlite_type_binary() { + assert!(matches!(sqlite_to_tds_type(Some("BINARY")), TdsType::VarBinary)); + } + + #[test] + fn sqlite_type_case_insensitive() { + assert!(matches!(sqlite_to_tds_type(Some("integer")), TdsType::BigInt)); + assert!(matches!(sqlite_to_tds_type(Some("Real")), TdsType::Float8)); + assert!(matches!(sqlite_to_tds_type(Some("blob")), TdsType::VarBinary)); + assert!(matches!(sqlite_to_tds_type(Some("text")), TdsType::NVarChar)); + } + + #[test] + fn sqlite_type_unknown_defaults_to_nvarchar() { + assert!(matches!(sqlite_to_tds_type(Some("DATE")), TdsType::NVarChar)); + assert!(matches!(sqlite_to_tds_type(Some("TIMESTAMP")), TdsType::NVarChar)); + assert!(matches!(sqlite_to_tds_type(Some("JSON")), TdsType::NVarChar)); + } + + // ── value_to_tds_type ────────────────────────────────────────────────── + + #[test] + fn value_type_null() { + assert!(matches!(value_to_tds_type(&Value::Null), TdsType::NVarChar)); + } + + #[test] + fn value_type_integer() { + assert!(matches!(value_to_tds_type(&Value::Integer(42)), TdsType::BigInt)); + } + + #[test] + fn value_type_float() { + assert!(matches!(value_to_tds_type(&Value::Float(3.14)), TdsType::Float8)); + } + + #[test] + fn value_type_text() { + assert!(matches!( + value_to_tds_type(&Value::Text("hello".into())), + TdsType::NVarChar + )); + } + + #[test] + fn value_type_blob() { + assert!(matches!( + value_to_tds_type(&Value::Blob(vec![1, 2, 3])), + TdsType::VarBinary + )); + } + + // ── build_columns ────────────────────────────────────────────────────── + + #[test] + fn build_columns_empty() { + let cols = build_columns(&[], None); + assert!(cols.is_empty()); + } + + #[test] + fn build_columns_with_declared_types() { + let columns = vec![ + Column { + name: "id".into(), + decltype: Some("INTEGER".into()), + }, + Column { + name: "name".into(), + decltype: Some("TEXT".into()), + }, + Column { + name: "score".into(), + decltype: Some("REAL".into()), + }, + Column { + name: "data".into(), + decltype: Some("BLOB".into()), + }, + ]; + let tds_cols = build_columns(&columns, None); + assert_eq!(tds_cols.len(), 4); + assert_eq!(tds_cols[0].name, "id"); + assert!(matches!(tds_cols[0].tds_type, TdsType::BigInt)); + assert_eq!(tds_cols[1].name, "name"); + assert!(matches!(tds_cols[1].tds_type, TdsType::NVarChar)); + assert_eq!(tds_cols[2].name, "score"); + assert!(matches!(tds_cols[2].tds_type, TdsType::Float8)); + assert_eq!(tds_cols[3].name, "data"); + assert!(matches!(tds_cols[3].tds_type, TdsType::VarBinary)); + } + + #[test] + fn build_columns_no_decltype_with_first_row() { + let columns = vec![ + Column { + name: "a".into(), + decltype: None, + }, + Column { + name: "b".into(), + decltype: None, + }, + Column { + name: "c".into(), + decltype: None, + }, + ]; + let first_row = vec![ + Value::Integer(1), + Value::Float(2.5), + Value::Text("hi".into()), + ]; + let tds_cols = build_columns(&columns, Some(&first_row)); + assert_eq!(tds_cols.len(), 3); + assert!(matches!(tds_cols[0].tds_type, TdsType::BigInt)); + assert!(matches!(tds_cols[1].tds_type, TdsType::Float8)); + assert!(matches!(tds_cols[2].tds_type, TdsType::NVarChar)); + } + + #[test] + fn build_columns_no_decltype_no_first_row() { + let columns = vec![ + Column { + name: "x".into(), + decltype: None, + }, + ]; + let tds_cols = build_columns(&columns, None); + assert_eq!(tds_cols.len(), 1); + assert!(matches!(tds_cols[0].tds_type, TdsType::NVarChar)); + } + + #[test] + fn build_columns_decltype_takes_precedence_over_first_row() { + let columns = vec![Column { + name: "val".into(), + decltype: Some("INTEGER".into()), + }]; + // Even if the first row has a float, the declared type wins. + let first_row = vec![Value::Float(1.0)]; + let tds_cols = build_columns(&columns, Some(&first_row)); + assert!(matches!(tds_cols[0].tds_type, TdsType::BigInt)); + } +} From c172683d7bb9e7f9c6b11e14d43df675cd15be40 Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:13:34 -0700 Subject: [PATCH 2/7] fix: sanitize identifiers to prevent SQL injection in metadata queries Add sanitize_identifier() that strips non-alphanumeric/underscore/period characters from table and schema names before they are interpolated into generated SQL. Applied at extraction time in extract_table_name(), extract_where_value_original(), and sp_columns parsing so all MetadataQuery variants are protected. --- crates/litewire-translate/src/metadata.rs | 105 ++++++++++++++++++++-- 1 file changed, 96 insertions(+), 9 deletions(-) diff --git a/crates/litewire-translate/src/metadata.rs b/crates/litewire-translate/src/metadata.rs index 4c8361f..3d938b0 100644 --- a/crates/litewire-translate/src/metadata.rs +++ b/crates/litewire-translate/src/metadata.rs @@ -299,10 +299,11 @@ pub fn detect_metadata_query(sql: &str, _dialect: Dialect) -> Option Option String { + name.chars() + .filter(|c| c.is_alphanumeric() || *c == '_' || *c == '.') + .collect() +} + /// Extract a simple `column = 'value'` filter from a WHERE clause. /// Searches on the uppercased SQL for the column name, but extracts the /// value from the original SQL to preserve case. @@ -338,13 +348,12 @@ fn extract_where_value_original(original_sql: &str, column: &str) -> Option String { let orig_word = orig_rest.split_ascii_whitespace().next().unwrap_or(word); // Strip backticks, double quotes, brackets. - orig_word + let raw = orig_word .trim_matches('`') .trim_matches('"') .trim_matches('[') .trim_end_matches(']') - .trim_end_matches(';') - .to_string() + .trim_end_matches(';'); + + sanitize_identifier(raw) } #[cfg(test)] @@ -912,4 +922,81 @@ mod tests { "got: {q:?}" ); } + + // ── SQL injection sanitization ──────────────────────────────────────── + + #[test] + fn table_name_injection_sanitized() { + let q = detect_metadata_query( + "DESCRIBE users'; DROP TABLE x; --", + Dialect::MySQL, + ); + match q { + Some(MetadataQuery::ShowColumns { table }) => { + assert!(!table.contains('\''), "quote not stripped: {table}"); + assert!(!table.contains(';'), "semicolon not stripped: {table}"); + assert!(!table.contains('-'), "dash not stripped: {table}"); + } + other => panic!("should detect as ShowColumns, got: {other:?}"), + } + } + + #[test] + fn schema_filter_injection_sanitized() { + let q = detect_metadata_query( + "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'main'; DROP TABLE x'", + Dialect::MySQL, + ); + match q { + Some(MetadataQuery::InformationSchemaTables { schema_filter }) => { + if let Some(schema) = &schema_filter { + assert!(!schema.contains(';'), "semicolon not stripped: {schema}"); + assert!(!schema.contains('\''), "quote not stripped: {schema}"); + assert!(!schema.contains('-'), "dash not stripped: {schema}"); + } + } + other => panic!("expected InformationSchemaTables, got: {other:?}"), + } + } + + #[test] + fn normal_identifiers_unchanged() { + // Simple underscore name + let q = detect_metadata_query("DESCRIBE my_table", Dialect::MySQL); + assert!( + matches!(q, Some(MetadataQuery::ShowColumns { ref table }) if table == "my_table"), + "got: {q:?}" + ); + + // Mixed case with digits + let q = detect_metadata_query("DESCRIBE MyTable123", Dialect::MySQL); + assert!( + matches!(q, Some(MetadataQuery::ShowColumns { ref table }) if table == "MyTable123"), + "got: {q:?}" + ); + + // Dot-separated (schema.table) + let q = detect_metadata_query("DESCRIBE schema.table", Dialect::MySQL); + assert!( + matches!(q, Some(MetadataQuery::ShowColumns { ref table }) if table == "schema.table"), + "got: {q:?}" + ); + } + + #[test] + fn generated_sql_safe_from_injection() { + let q = detect_metadata_query( + "DESCRIBE users'; DROP TABLE sqlite_master; --", + Dialect::MySQL, + ); + let sql = q.unwrap().to_sqlite_sql(); + assert!( + !sql.contains(';'), + "generated SQL contains semicolon: {sql}" + ); + assert!( + !sql.contains("--"), + "generated SQL contains comment marker: {sql}" + ); + } } From 1321db48406e203f9b9b50947803a5e0b61edccf Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:14:15 -0700 Subject: [PATCH 3/7] Add unit tests for litewire-backend edge cases Cover edge case values (integer/float extremes, empty text/blob, large payloads, Unicode/emoji), error paths (nonexistent table, wrong param count, SQL syntax errors), mixed-type rows, sequential query-execute patterns, and hrana client conversion fallbacks (non-numeric integer parse, invalid base64, overflow) plus URL trailing-slash handling. --- crates/litewire-backend/src/hrana_client.rs | 287 ++++++++++++++ .../litewire-backend/src/rusqlite_backend.rs | 349 ++++++++++++++++++ 2 files changed, 636 insertions(+) diff --git a/crates/litewire-backend/src/hrana_client.rs b/crates/litewire-backend/src/hrana_client.rs index fa061f4..757985a 100644 --- a/crates/litewire-backend/src/hrana_client.rs +++ b/crates/litewire-backend/src/hrana_client.rs @@ -428,4 +428,291 @@ mod tests { other => panic!("expected Other error, got: {other:?}"), } } + + // ── value_to_hrana: Float (was missing) ───────────────────────────────── + + #[test] + fn value_to_hrana_float() { + match value_to_hrana(&Value::Float(3.14)) { + HranaValue::Float { value } => { + assert!((value - 3.14).abs() < f64::EPSILON); + } + other => panic!("expected Float, got: {other:?}"), + } + } + + // ── value_to_hrana: edge cases ────────────────────────────────────────── + + #[test] + fn value_to_hrana_integer_max() { + match value_to_hrana(&Value::Integer(i64::MAX)) { + HranaValue::Integer { value } => assert_eq!(value, i64::MAX.to_string()), + other => panic!("expected Integer, got: {other:?}"), + } + } + + #[test] + fn value_to_hrana_integer_min() { + match value_to_hrana(&Value::Integer(i64::MIN)) { + HranaValue::Integer { value } => { + assert_eq!(value, i64::MIN.to_string()); + assert!(value.starts_with('-')); + } + other => panic!("expected Integer, got: {other:?}"), + } + } + + #[test] + fn value_to_hrana_integer_zero() { + match value_to_hrana(&Value::Integer(0)) { + HranaValue::Integer { value } => assert_eq!(value, "0"), + other => panic!("expected Integer, got: {other:?}"), + } + } + + #[test] + fn value_to_hrana_empty_text() { + match value_to_hrana(&Value::Text(String::new())) { + HranaValue::Text { value } => assert_eq!(value, ""), + other => panic!("expected Text, got: {other:?}"), + } + } + + #[test] + fn value_to_hrana_empty_blob() { + match value_to_hrana(&Value::Blob(vec![])) { + HranaValue::Blob { base64: b64 } => assert_eq!(b64, ""), + other => panic!("expected Blob, got: {other:?}"), + } + } + + // ── response_value_to_backend: edge cases ─────────────────────────────── + + #[test] + fn response_value_integer_non_numeric_defaults_to_zero() { + let rv = ResponseValue::Integer { + value: "abc".into(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Integer(0)); + } + + #[test] + fn response_value_integer_overflow_defaults_to_zero() { + let rv = ResponseValue::Integer { + value: "99999999999999999999".into(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Integer(0)); + } + + #[test] + fn response_value_integer_negative() { + let rv = ResponseValue::Integer { + value: "-42".into(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Integer(-42)); + } + + #[test] + fn response_value_integer_empty_string_defaults_to_zero() { + let rv = ResponseValue::Integer { + value: String::new(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Integer(0)); + } + + #[test] + fn response_value_blob_invalid_base64_defaults_to_empty() { + let rv = ResponseValue::Blob { + base64: "!!!not-valid-base64!!!".into(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Blob(vec![])); + } + + #[test] + fn response_value_blob_valid_base64() { + use base64::Engine; + let data = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let b64 = base64::engine::general_purpose::STANDARD.encode(&data); + let rv = ResponseValue::Blob { base64: b64 }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Blob(data)); + } + + #[test] + fn response_value_empty_text() { + let rv = ResponseValue::Text { + value: String::new(), + }; + let val = response_value_to_backend(&rv); + assert_eq!(val, Value::Text(String::new())); + } + + // ── execute_response_to_result_set: edge cases ────────────────────────── + + #[test] + fn execute_response_empty_rows_nonempty_columns() { + let exec = ExecuteResponse { + cols: vec![ + ColResponse { + name: "id".into(), + decltype: Some("INTEGER".into()), + }, + ColResponse { + name: "name".into(), + decltype: Some("TEXT".into()), + }, + ], + rows: vec![], + affected_row_count: 0, + last_insert_rowid: None, + }; + let rs = execute_response_to_result_set(exec); + assert_eq!(rs.columns.len(), 2); + assert!(rs.rows.is_empty()); + } + + #[test] + fn execute_response_empty_columns_and_rows() { + let exec = ExecuteResponse { + cols: vec![], + rows: vec![], + affected_row_count: 0, + last_insert_rowid: None, + }; + let rs = execute_response_to_result_set(exec); + assert!(rs.columns.is_empty()); + assert!(rs.rows.is_empty()); + } + + #[test] + fn execute_response_rows_with_null_values() { + let exec = ExecuteResponse { + cols: vec![ + ColResponse { + name: "a".into(), + decltype: None, + }, + ColResponse { + name: "b".into(), + decltype: None, + }, + ], + rows: vec![vec![ResponseValue::Null, ResponseValue::Null]], + affected_row_count: 0, + last_insert_rowid: None, + }; + let rs = execute_response_to_result_set(exec); + assert_eq!(rs.rows.len(), 1); + assert_eq!(rs.rows[0][0], Value::Null); + assert_eq!(rs.rows[0][1], Value::Null); + } + + #[test] + fn execute_response_mixed_value_types_in_row() { + let exec = ExecuteResponse { + cols: vec![ + ColResponse { + name: "a".into(), + decltype: None, + }, + ColResponse { + name: "b".into(), + decltype: None, + }, + ColResponse { + name: "c".into(), + decltype: None, + }, + ColResponse { + name: "d".into(), + decltype: None, + }, + ColResponse { + name: "e".into(), + decltype: None, + }, + ], + rows: vec![vec![ + ResponseValue::Null, + ResponseValue::Integer { + value: "7".into(), + }, + ResponseValue::Float { value: 2.5 }, + ResponseValue::Text { + value: "hello".into(), + }, + ResponseValue::Blob { + base64: "AAEC".into(), // [0, 1, 2] + }, + ]], + affected_row_count: 0, + last_insert_rowid: None, + }; + let rs = execute_response_to_result_set(exec); + assert_eq!(rs.rows.len(), 1); + assert_eq!(rs.rows[0][0], Value::Null); + assert_eq!(rs.rows[0][1], Value::Integer(7)); + assert_eq!(rs.rows[0][2], Value::Float(2.5)); + assert_eq!(rs.rows[0][3], Value::Text("hello".into())); + assert_eq!(rs.rows[0][4], Value::Blob(vec![0, 1, 2])); + } + + #[test] + fn execute_response_column_decltype_none() { + let exec = ExecuteResponse { + cols: vec![ColResponse { + name: "expr".into(), + decltype: None, + }], + rows: vec![vec![ResponseValue::Integer { + value: "1".into(), + }]], + affected_row_count: 0, + last_insert_rowid: None, + }; + let rs = execute_response_to_result_set(exec); + assert!(rs.columns[0].decltype.is_none()); + } + + // ── HranaClient::new() URL handling ───────────────────────────────────── + + #[test] + fn hrana_client_url_trailing_slash_trimmed() { + let client = HranaClient::new("http://localhost:8081/"); + assert_eq!(client.pipeline_url, "http://localhost:8081/v2/pipeline"); + assert_eq!(client.health_url, "http://localhost:8081/health"); + } + + #[test] + fn hrana_client_url_no_trailing_slash() { + let client = HranaClient::new("http://localhost:8081"); + assert_eq!(client.pipeline_url, "http://localhost:8081/v2/pipeline"); + assert_eq!(client.health_url, "http://localhost:8081/health"); + } + + #[test] + fn hrana_client_url_multiple_trailing_slashes_trimmed() { + let client = HranaClient::new("http://localhost:8081///"); + assert_eq!(client.pipeline_url, "http://localhost:8081/v2/pipeline"); + assert_eq!(client.health_url, "http://localhost:8081/health"); + } + + // ── hrana_error_to_backend: additional cases ──────────────────────────── + + #[test] + fn hrana_error_with_non_sqlite_code() { + let err = hrana_error_to_backend(ErrorResponse { + message: "something else".into(), + code: Some("INTERNAL_ERROR".into()), + }); + match err { + BackendError::Other(msg) => assert_eq!(msg, "something else"), + other => panic!("expected Other error, got: {other:?}"), + } + } } diff --git a/crates/litewire-backend/src/rusqlite_backend.rs b/crates/litewire-backend/src/rusqlite_backend.rs index de086d2..9abeb4e 100644 --- a/crates/litewire-backend/src/rusqlite_backend.rs +++ b/crates/litewire-backend/src/rusqlite_backend.rs @@ -405,4 +405,353 @@ mod tests { assert_eq!(rs.rows[0][0], Value::Integer(1)); assert_eq!(rs.rows[0][1], Value::Text("hello".into())); } + + // ── Concurrent / sequential operations ────────────────────────────────── + + #[tokio::test] + async fn query_execute_query_sequential() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)", &[]) + .await + .unwrap(); + + // Query empty table + let rs = backend.query("SELECT * FROM t", &[]).await.unwrap(); + assert!(rs.rows.is_empty()); + + // Insert a row + backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[Value::Integer(1), Value::Text("first".into())], + ) + .await + .unwrap(); + + // Query again and see the new row + let rs = backend.query("SELECT * FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows.len(), 1); + assert_eq!(rs.rows[0][1], Value::Text("first".into())); + + // Insert another row + backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[Value::Integer(2), Value::Text("second".into())], + ) + .await + .unwrap(); + + // Verify both rows exist + let rs = backend + .query("SELECT * FROM t ORDER BY id", &[]) + .await + .unwrap(); + assert_eq!(rs.rows.len(), 2); + } + + #[tokio::test] + async fn explicit_integer_primary_key_rowid() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)", &[]) + .await + .unwrap(); + + // Insert with an explicit id of 42 + let result = backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[Value::Integer(42), Value::Text("at 42".into())], + ) + .await + .unwrap(); + assert_eq!(result.last_insert_rowid, Some(42)); + + // Insert with an explicit id of 100 + let result = backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[Value::Integer(100), Value::Text("at 100".into())], + ) + .await + .unwrap(); + assert_eq!(result.last_insert_rowid, Some(100)); + } + + // ── Edge case values ──────────────────────────────────────────────────── + + #[tokio::test] + async fn integer_extremes() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v INTEGER)", &[]) + .await + .unwrap(); + + for val in [i64::MIN, i64::MAX, 0] { + backend + .execute("INSERT INTO t VALUES (?1)", &[Value::Integer(val)]) + .await + .unwrap(); + } + + let rs = backend.query("SELECT v FROM t ORDER BY rowid", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Integer(i64::MIN)); + assert_eq!(rs.rows[1][0], Value::Integer(i64::MAX)); + assert_eq!(rs.rows[2][0], Value::Integer(0)); + } + + #[tokio::test] + async fn float_extremes() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v REAL)", &[]) + .await + .unwrap(); + + for val in [f64::MIN, f64::MAX, f64::MIN_POSITIVE] { + backend + .execute("INSERT INTO t VALUES (?1)", &[Value::Float(val)]) + .await + .unwrap(); + } + + let rs = backend.query("SELECT v FROM t ORDER BY rowid", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Float(f64::MIN)); + assert_eq!(rs.rows[1][0], Value::Float(f64::MAX)); + assert_eq!(rs.rows[2][0], Value::Float(f64::MIN_POSITIVE)); + } + + #[tokio::test] + async fn empty_text_roundtrip() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v TEXT)", &[]) + .await + .unwrap(); + backend + .execute("INSERT INTO t VALUES (?1)", &[Value::Text(String::new())]) + .await + .unwrap(); + + let rs = backend.query("SELECT v FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Text(String::new())); + } + + #[tokio::test] + async fn empty_blob_roundtrip() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v BLOB)", &[]) + .await + .unwrap(); + backend + .execute("INSERT INTO t VALUES (?1)", &[Value::Blob(vec![])]) + .await + .unwrap(); + + let rs = backend.query("SELECT v FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Blob(vec![])); + } + + #[tokio::test] + async fn large_text_roundtrip() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v TEXT)", &[]) + .await + .unwrap(); + + let large = "x".repeat(10_000); + backend + .execute( + "INSERT INTO t VALUES (?1)", + &[Value::Text(large.clone())], + ) + .await + .unwrap(); + + let rs = backend.query("SELECT v FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Text(large)); + } + + #[tokio::test] + async fn large_blob_roundtrip() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v BLOB)", &[]) + .await + .unwrap(); + + let large: Vec = (0..10_000).map(|i| (i % 256) as u8).collect(); + backend + .execute( + "INSERT INTO t VALUES (?1)", + &[Value::Blob(large.clone())], + ) + .await + .unwrap(); + + let rs = backend.query("SELECT v FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows[0][0], Value::Blob(large)); + } + + #[tokio::test] + async fn text_with_special_characters() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (v TEXT)", &[]) + .await + .unwrap(); + + let specials = vec![ + "hello\0world", // null byte + "こんにちは", // Japanese + "🎉🚀💯", // emoji + "café résumé naïve", // accented + "line1\nline2\ttab", // newline and tab + "'quotes' and \"double\"", // quotes + ]; + + for s in &specials { + backend + .execute( + "INSERT INTO t VALUES (?1)", + &[Value::Text(s.to_string())], + ) + .await + .unwrap(); + } + + let rs = backend.query("SELECT v FROM t ORDER BY rowid", &[]).await.unwrap(); + for (i, s) in specials.iter().enumerate() { + assert_eq!(rs.rows[i][0], Value::Text(s.to_string())); + } + } + + // ── Mixed types in a single row ───────────────────────────────────────── + + #[tokio::test] + async fn insert_all_types_in_one_row() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute( + "CREATE TABLE t (a, b INTEGER, c REAL, d TEXT, e BLOB)", + &[], + ) + .await + .unwrap(); + + backend + .execute( + "INSERT INTO t VALUES (?1, ?2, ?3, ?4, ?5)", + &[ + Value::Null, + Value::Integer(99), + Value::Float(1.5), + Value::Text("mixed".into()), + Value::Blob(vec![0xAB, 0xCD]), + ], + ) + .await + .unwrap(); + + let rs = backend.query("SELECT * FROM t", &[]).await.unwrap(); + assert_eq!(rs.rows.len(), 1); + assert_eq!(rs.rows[0][0], Value::Null); + assert_eq!(rs.rows[0][1], Value::Integer(99)); + assert_eq!(rs.rows[0][2], Value::Float(1.5)); + assert_eq!(rs.rows[0][3], Value::Text("mixed".into())); + assert_eq!(rs.rows[0][4], Value::Blob(vec![0xAB, 0xCD])); + } + + // ── Error cases ───────────────────────────────────────────────────────── + + #[tokio::test] + async fn query_nonexistent_table() { + let backend = Rusqlite::memory().unwrap(); + let result = backend.query("SELECT * FROM nonexistent_table", &[]).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.to_string().contains("no such table"), + "expected 'no such table' in error: {err}" + ); + } + + #[tokio::test] + async fn execute_wrong_param_count() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (a INTEGER, b TEXT)", &[]) + .await + .unwrap(); + + // Provide 3 params when only 2 are expected + let result = backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[ + Value::Integer(1), + Value::Text("hello".into()), + Value::Integer(99), + ], + ) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn execute_sql_syntax_error() { + let backend = Rusqlite::memory().unwrap(); + let result = backend.execute("INSERT INTO VALUES ()", &[]).await; + assert!(result.is_err()); + } + + // ── UPDATE and DELETE affected_rows ────────────────────────────────────── + + #[tokio::test] + async fn update_affected_rows() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (id INTEGER, v TEXT)", &[]) + .await + .unwrap(); + + for i in 0..3 { + backend + .execute( + "INSERT INTO t VALUES (?1, ?2)", + &[Value::Integer(i), Value::Text("same".into())], + ) + .await + .unwrap(); + } + + let result = backend + .execute( + "UPDATE t SET v = ?1 WHERE v = ?2", + &[Value::Text("changed".into()), Value::Text("same".into())], + ) + .await + .unwrap(); + assert_eq!(result.affected_rows, 3); + } + + #[tokio::test] + async fn delete_no_matching_rows() { + let backend = Rusqlite::memory().unwrap(); + backend + .execute("CREATE TABLE t (id INTEGER)", &[]) + .await + .unwrap(); + + let result = backend + .execute("DELETE FROM t WHERE id = ?1", &[Value::Integer(999)]) + .await + .unwrap(); + assert_eq!(result.affected_rows, 0); + } } From 285e416d193eea38dfad14cf57b2aee282fabb48 Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:14:54 -0700 Subject: [PATCH 4/7] add comprehensive unit tests for emit.rs SQL output stage Adds 70+ roundtrip tests via translate() covering complex SELECTs (GROUP BY, HAVING, DISTINCT, LIMIT/OFFSET, aggregates, CASE WHEN, subqueries), JOINs (INNER, LEFT, RIGHT, CROSS, self-join, multi-join), set operations (UNION, INTERSECT, EXCEPT), CTEs (simple, multiple, recursive), complex DML (INSERT...SELECT, multi-row INSERT, UPDATE/DELETE with subqueries), DDL (CREATE/DROP TABLE IF [NOT] EXISTS, CREATE INDEX, CREATE UNIQUE INDEX), and transaction statements (BEGIN, COMMIT, ROLLBACK, SAVEPOINT, RELEASE, ROLLBACK TO). Tests cover MySQL, PostgreSQL, and TDS dialects. --- crates/litewire-translate/src/emit.rs | 909 ++++++++++++++++++++++++++ 1 file changed, 909 insertions(+) diff --git a/crates/litewire-translate/src/emit.rs b/crates/litewire-translate/src/emit.rs index 97fe023..d514e68 100644 --- a/crates/litewire-translate/src/emit.rs +++ b/crates/litewire-translate/src/emit.rs @@ -95,4 +95,913 @@ mod tests { let emitted = emit_statement(&stmts[0]); assert!(emitted.contains("SELECT user_id FROM orders")); } + + // ── Roundtrip tests via translate() ──────────────────────────────────── + + use crate::{translate, Dialect, TranslateResult}; + + /// Helper: translate SQL via a given dialect and return the emitted SQL string. + fn translate_sql(input: &str, dialect: Dialect) -> String { + let results = translate(input, dialect).unwrap(); + match &results[0] { + TranslateResult::Sql(s) => s.clone(), + other => panic!("expected Sql, got: {other:?}"), + } + } + + /// Helper: assert the emitted SQL contains a given fragment (case-insensitive). + fn assert_contains(sql: &str, fragment: &str) { + let upper = sql.to_ascii_uppercase(); + assert!( + upper.contains(&fragment.to_ascii_uppercase()), + "expected '{fragment}' in: {sql}" + ); + } + + // ── 1. Complex SELECT statements ─────────────────────────────────────── + + #[test] + fn select_group_by_having() { + let sql = translate_sql( + "SELECT department, COUNT(*) AS cnt FROM employees GROUP BY department HAVING COUNT(*) > 5", + Dialect::MySQL, + ); + assert_contains(&sql, "GROUP BY"); + assert_contains(&sql, "HAVING"); + assert_contains(&sql, "COUNT(*)"); + } + + #[test] + fn select_order_by_multiple_columns() { + let sql = translate_sql( + "SELECT id, name, age FROM users ORDER BY age DESC, name ASC", + Dialect::MySQL, + ); + assert_contains(&sql, "ORDER BY"); + assert_contains(&sql, "DESC"); + assert_contains(&sql, "ASC"); + } + + #[test] + fn select_distinct() { + let sql = translate_sql( + "SELECT DISTINCT department FROM employees", + Dialect::MySQL, + ); + assert_contains(&sql, "DISTINCT"); + assert_contains(&sql, "department"); + } + + #[test] + fn select_limit_offset_standard() { + let sql = translate_sql( + "SELECT * FROM users LIMIT 10 OFFSET 20", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "LIMIT 10"); + assert_contains(&sql, "OFFSET 20"); + } + + #[test] + fn select_aggregate_count() { + let sql = translate_sql("SELECT COUNT(*) FROM users", Dialect::MySQL); + assert_contains(&sql, "COUNT(*)"); + } + + #[test] + fn select_aggregate_sum() { + let sql = translate_sql( + "SELECT SUM(amount) FROM orders", + Dialect::MySQL, + ); + assert_contains(&sql, "SUM("); + assert_contains(&sql, "amount"); + } + + #[test] + fn select_aggregate_avg() { + let sql = translate_sql( + "SELECT AVG(price) FROM products", + Dialect::MySQL, + ); + assert_contains(&sql, "AVG("); + assert_contains(&sql, "price"); + } + + #[test] + fn select_aggregate_min_max() { + let sql = translate_sql( + "SELECT MIN(created_at), MAX(created_at) FROM events", + Dialect::MySQL, + ); + assert_contains(&sql, "MIN("); + assert_contains(&sql, "MAX("); + } + + #[test] + fn select_case_when() { + let sql = translate_sql( + "SELECT CASE WHEN status = 'active' THEN 'yes' WHEN status = 'inactive' THEN 'no' ELSE 'unknown' END AS label FROM users", + Dialect::MySQL, + ); + assert_contains(&sql, "CASE"); + assert_contains(&sql, "WHEN"); + assert_contains(&sql, "THEN"); + assert_contains(&sql, "ELSE"); + assert_contains(&sql, "END"); + } + + #[test] + fn select_subquery_in_where() { + let sql = translate_sql( + "SELECT name FROM users WHERE id IN (SELECT user_id FROM orders WHERE total > 100)", + Dialect::MySQL, + ); + assert_contains(&sql, "IN (SELECT"); + assert_contains(&sql, "user_id"); + assert_contains(&sql, "total > 100"); + } + + #[test] + fn select_subquery_in_from() { + let sql = translate_sql( + "SELECT sub.name FROM (SELECT name FROM users WHERE active = 1) AS sub", + Dialect::MySQL, + ); + assert_contains(&sql, "sub"); + assert_contains(&sql, "SELECT name FROM users"); + } + + #[test] + fn select_with_alias() { + let sql = translate_sql( + "SELECT u.name AS user_name, COUNT(o.id) AS order_count FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.name", + Dialect::MySQL, + ); + assert_contains(&sql, "AS user_name"); + assert_contains(&sql, "AS order_count"); + assert_contains(&sql, "GROUP BY"); + } + + #[test] + fn select_between() { + let sql = translate_sql( + "SELECT * FROM events WHERE created_at BETWEEN '2024-01-01' AND '2024-12-31'", + Dialect::MySQL, + ); + assert_contains(&sql, "BETWEEN"); + assert_contains(&sql, "2024-01-01"); + assert_contains(&sql, "2024-12-31"); + } + + #[test] + fn select_is_null_is_not_null() { + let sql = translate_sql( + "SELECT * FROM users WHERE email IS NOT NULL AND phone IS NULL", + Dialect::MySQL, + ); + assert_contains(&sql, "IS NOT NULL"); + assert_contains(&sql, "IS NULL"); + } + + #[test] + fn select_like() { + let sql = translate_sql( + "SELECT * FROM users WHERE name LIKE '%Alice%'", + Dialect::MySQL, + ); + assert_contains(&sql, "LIKE"); + assert_contains(&sql, "%Alice%"); + } + + #[test] + fn select_in_list() { + let sql = translate_sql( + "SELECT * FROM users WHERE status IN ('active', 'pending', 'trial')", + Dialect::MySQL, + ); + assert_contains(&sql, "IN ("); + assert_contains(&sql, "'active'"); + assert_contains(&sql, "'pending'"); + assert_contains(&sql, "'trial'"); + } + + #[test] + fn select_exists_subquery() { + let sql = translate_sql( + "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id)", + Dialect::MySQL, + ); + assert_contains(&sql, "EXISTS"); + assert_contains(&sql, "SELECT 1"); + } + + #[test] + fn select_count_distinct() { + let sql = translate_sql( + "SELECT COUNT(DISTINCT department) FROM employees", + Dialect::MySQL, + ); + assert_contains(&sql, "COUNT(DISTINCT"); + assert_contains(&sql, "department"); + } + + // ── 2. JOINs ─────────────────────────────────────────────────────────── + + #[test] + fn inner_join_with_on() { + let sql = translate_sql( + "SELECT u.name, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id", + Dialect::MySQL, + ); + assert_contains(&sql, "INNER JOIN"); + // Some dialects emit just JOIN; check for ON clause too. + assert_contains(&sql, "ON u.id = o.user_id"); + } + + #[test] + fn left_join() { + let sql = translate_sql( + "SELECT u.name, o.total FROM users u LEFT JOIN orders o ON u.id = o.user_id", + Dialect::MySQL, + ); + assert_contains(&sql, "LEFT JOIN"); + assert_contains(&sql, "ON u.id = o.user_id"); + } + + #[test] + fn left_outer_join() { + let sql = translate_sql( + "SELECT u.name, o.total FROM users u LEFT OUTER JOIN orders o ON u.id = o.user_id", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "LEFT"); + assert_contains(&sql, "JOIN"); + assert_contains(&sql, "ON u.id = o.user_id"); + } + + #[test] + fn multiple_joins() { + let sql = translate_sql( + "SELECT u.name, o.total, p.name AS product FROM users u \ + JOIN orders o ON u.id = o.user_id \ + JOIN products p ON o.product_id = p.id", + Dialect::MySQL, + ); + assert_contains(&sql, "JOIN"); + assert_contains(&sql, "u.id = o.user_id"); + assert_contains(&sql, "o.product_id = p.id"); + } + + #[test] + fn self_join_with_aliases() { + let sql = translate_sql( + "SELECT e.name AS employee, m.name AS manager FROM employees e LEFT JOIN employees m ON e.manager_id = m.id", + Dialect::MySQL, + ); + assert_contains(&sql, "AS employee"); + assert_contains(&sql, "AS manager"); + assert_contains(&sql, "e.manager_id = m.id"); + } + + #[test] + fn cross_join() { + let sql = translate_sql( + "SELECT a.x, b.y FROM t1 a CROSS JOIN t2 b", + Dialect::MySQL, + ); + assert_contains(&sql, "CROSS JOIN"); + } + + #[test] + fn right_join() { + let sql = translate_sql( + "SELECT u.name, o.id FROM users u RIGHT JOIN orders o ON u.id = o.user_id", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "RIGHT JOIN"); + assert_contains(&sql, "ON u.id = o.user_id"); + } + + // ── 3. Set operations ────────────────────────────────────────────────── + + #[test] + fn union() { + let sql = translate_sql( + "SELECT name FROM customers UNION SELECT name FROM suppliers", + Dialect::MySQL, + ); + assert_contains(&sql, "UNION"); + assert_contains(&sql, "customers"); + assert_contains(&sql, "suppliers"); + } + + #[test] + fn union_all() { + let sql = translate_sql( + "SELECT name FROM customers UNION ALL SELECT name FROM suppliers", + Dialect::MySQL, + ); + assert_contains(&sql, "UNION ALL"); + } + + #[test] + fn multiple_unions() { + let sql = translate_sql( + "SELECT name FROM customers UNION SELECT name FROM suppliers UNION SELECT name FROM partners", + Dialect::MySQL, + ); + assert_contains(&sql, "UNION"); + assert_contains(&sql, "customers"); + assert_contains(&sql, "suppliers"); + assert_contains(&sql, "partners"); + } + + #[test] + fn intersect() { + let sql = translate_sql( + "SELECT id FROM a INTERSECT SELECT id FROM b", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "INTERSECT"); + } + + #[test] + fn except() { + let sql = translate_sql( + "SELECT id FROM a EXCEPT SELECT id FROM b", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "EXCEPT"); + } + + // ── 4. CTEs (Common Table Expressions) ───────────────────────────────── + + #[test] + fn simple_cte() { + let sql = translate_sql( + "WITH active_users AS (SELECT * FROM users WHERE active = 1) SELECT * FROM active_users", + Dialect::MySQL, + ); + assert_contains(&sql, "WITH"); + assert_contains(&sql, "active_users"); + assert_contains(&sql, "active = 1"); + } + + #[test] + fn cte_used_in_join() { + let sql = translate_sql( + "WITH recent_orders AS (SELECT user_id, SUM(total) AS total FROM orders GROUP BY user_id) \ + SELECT u.name, r.total FROM users u JOIN recent_orders r ON u.id = r.user_id", + Dialect::MySQL, + ); + assert_contains(&sql, "WITH"); + assert_contains(&sql, "recent_orders"); + assert_contains(&sql, "SUM("); + assert_contains(&sql, "JOIN"); + } + + #[test] + fn multiple_ctes() { + let sql = translate_sql( + "WITH cte1 AS (SELECT 1 AS a), cte2 AS (SELECT 2 AS b) SELECT * FROM cte1, cte2", + Dialect::MySQL, + ); + assert_contains(&sql, "WITH"); + assert_contains(&sql, "cte1"); + assert_contains(&sql, "cte2"); + } + + // ── 5. Complex DML after rewriting ───────────────────────────────────── + + #[test] + fn insert_select() { + let sql = translate_sql( + "INSERT INTO archive (id, name) SELECT id, name FROM users WHERE active = 0", + Dialect::MySQL, + ); + assert_contains(&sql, "INSERT INTO"); + assert_contains(&sql, "archive"); + assert_contains(&sql, "SELECT id, name FROM users"); + assert_contains(&sql, "active = 0"); + } + + #[test] + fn insert_multiple_rows() { + let sql = translate_sql( + "INSERT INTO users (name, age) VALUES ('Alice', 30), ('Bob', 25), ('Carol', 35)", + Dialect::MySQL, + ); + assert_contains(&sql, "INSERT INTO"); + assert_contains(&sql, "Alice"); + assert_contains(&sql, "Bob"); + assert_contains(&sql, "Carol"); + } + + #[test] + fn update_with_subquery_in_where() { + let sql = translate_sql( + "UPDATE users SET active = 0 WHERE id IN (SELECT user_id FROM banned)", + Dialect::MySQL, + ); + assert_contains(&sql, "UPDATE"); + assert_contains(&sql, "active = 0"); + assert_contains(&sql, "IN (SELECT user_id FROM banned)"); + } + + #[test] + fn delete_with_subquery() { + let sql = translate_sql( + "DELETE FROM orders WHERE user_id IN (SELECT id FROM users WHERE active = 0)", + Dialect::MySQL, + ); + assert_contains(&sql, "DELETE"); + assert_contains(&sql, "orders"); + assert_contains(&sql, "IN (SELECT id FROM users"); + } + + #[test] + fn update_multiple_columns() { + let sql = translate_sql( + "UPDATE users SET name = 'Updated', age = 99, active = 0 WHERE id = 42", + Dialect::MySQL, + ); + assert_contains(&sql, "UPDATE users SET"); + assert_contains(&sql, "name = 'Updated'"); + assert_contains(&sql, "age = 99"); + assert_contains(&sql, "active = 0"); + assert_contains(&sql, "id = 42"); + } + + #[test] + fn delete_with_and_or_conditions() { + let sql = translate_sql( + "DELETE FROM events WHERE (status = 'expired' OR status = 'cancelled') AND created_at < '2020-01-01'", + Dialect::MySQL, + ); + assert_contains(&sql, "DELETE"); + assert_contains(&sql, "expired"); + assert_contains(&sql, "cancelled"); + assert_contains(&sql, "AND"); + } + + // ── 6. DDL ───────────────────────────────────────────────────────────── + + #[test] + fn create_table_if_not_exists() { + let sql = translate_sql( + "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT NOT NULL)", + Dialect::MySQL, + ); + assert_contains(&sql, "CREATE TABLE IF NOT EXISTS"); + assert_contains(&sql, "users"); + assert_contains(&sql, "PRIMARY KEY"); + assert_contains(&sql, "NOT NULL"); + } + + #[test] + fn drop_table_if_exists() { + let sql = translate_sql( + "DROP TABLE IF EXISTS old_users", + Dialect::MySQL, + ); + assert_contains(&sql, "DROP TABLE IF EXISTS"); + assert_contains(&sql, "old_users"); + } + + #[test] + fn create_index() { + let sql = translate_sql( + "CREATE INDEX idx_users_email ON users (email)", + Dialect::MySQL, + ); + assert_contains(&sql, "CREATE INDEX"); + assert_contains(&sql, "idx_users_email"); + assert_contains(&sql, "users"); + assert_contains(&sql, "email"); + } + + #[test] + fn create_unique_index() { + let sql = translate_sql( + "CREATE UNIQUE INDEX idx_users_email ON users (email)", + Dialect::MySQL, + ); + assert_contains(&sql, "UNIQUE INDEX"); + assert_contains(&sql, "idx_users_email"); + assert_contains(&sql, "users"); + } + + #[test] + fn drop_table_plain() { + let sql = translate_sql( + "DROP TABLE sessions", + Dialect::MySQL, + ); + assert_contains(&sql, "DROP TABLE"); + assert_contains(&sql, "sessions"); + } + + #[test] + fn create_table_multiple_constraints() { + let sql = translate_sql( + "CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER NOT NULL, amount INTEGER DEFAULT 0)", + Dialect::MySQL, + ); + assert_contains(&sql, "PRIMARY KEY"); + assert_contains(&sql, "NOT NULL"); + assert_contains(&sql, "DEFAULT"); + } + + #[test] + fn create_index_composite() { + let sql = translate_sql( + "CREATE INDEX idx_composite ON orders (user_id, created_at)", + Dialect::MySQL, + ); + assert_contains(&sql, "CREATE INDEX"); + assert_contains(&sql, "user_id"); + assert_contains(&sql, "created_at"); + } + + // ── 7. Transaction statements ────────────────────────────────────────── + + #[test] + fn begin_transaction() { + let sql = translate_sql("BEGIN", Dialect::MySQL); + // sqlparser may emit BEGIN TRANSACTION or just BEGIN + let upper = sql.to_ascii_uppercase(); + assert!( + upper.contains("BEGIN"), + "expected BEGIN in: {sql}" + ); + } + + #[test] + fn commit_transaction() { + let sql = translate_sql("COMMIT", Dialect::MySQL); + assert_contains(&sql, "COMMIT"); + } + + #[test] + fn rollback_transaction() { + let sql = translate_sql("ROLLBACK", Dialect::MySQL); + assert_contains(&sql, "ROLLBACK"); + } + + #[test] + fn savepoint() { + let sql = translate_sql("SAVEPOINT sp1", Dialect::MySQL); + assert_contains(&sql, "SAVEPOINT"); + assert_contains(&sql, "sp1"); + } + + #[test] + fn release_savepoint() { + let sql = translate_sql("RELEASE SAVEPOINT sp1", Dialect::MySQL); + assert_contains(&sql, "RELEASE SAVEPOINT"); + assert_contains(&sql, "sp1"); + } + + #[test] + fn rollback_to_savepoint() { + let sql = translate_sql( + "ROLLBACK TO SAVEPOINT sp1", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "ROLLBACK TO SAVEPOINT"); + assert_contains(&sql, "sp1"); + } + + // ── 8. Cross-dialect roundtrip coverage ──────────────────────────────── + + #[test] + fn pg_select_group_by_having() { + let sql = translate_sql( + "SELECT category, AVG(price) FROM products GROUP BY category HAVING AVG(price) > 50", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "GROUP BY"); + assert_contains(&sql, "HAVING"); + assert_contains(&sql, "AVG("); + } + + #[test] + fn pg_cte_with_join() { + let sql = translate_sql( + "WITH top_users AS (SELECT id, name FROM users ORDER BY score DESC LIMIT 10) \ + SELECT t.name, COUNT(o.id) FROM top_users t LEFT JOIN orders o ON t.id = o.user_id GROUP BY t.name", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "WITH"); + assert_contains(&sql, "top_users"); + assert_contains(&sql, "LEFT JOIN"); + assert_contains(&sql, "GROUP BY"); + } + + #[test] + fn tds_select_with_where_and_order() { + let sql = translate_sql( + "SELECT id, name FROM users WHERE active = 1 ORDER BY name", + Dialect::TDS, + ); + assert_contains(&sql, "SELECT"); + assert_contains(&sql, "WHERE"); + assert_contains(&sql, "ORDER BY"); + } + + #[test] + fn pg_insert_select() { + let sql = translate_sql( + "INSERT INTO log (event) SELECT event_name FROM events WHERE created_at > '2024-01-01'", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "INSERT INTO"); + assert_contains(&sql, "SELECT event_name"); + } + + #[test] + fn mysql_create_table_if_not_exists_simple() { + let sql = translate_sql( + "CREATE TABLE IF NOT EXISTS settings (id INTEGER PRIMARY KEY, value TEXT)", + Dialect::MySQL, + ); + assert_contains(&sql, "IF NOT EXISTS"); + } + + #[test] + fn pg_union_all_with_order() { + let sql = translate_sql( + "SELECT name, 'customer' AS type FROM customers UNION ALL SELECT name, 'supplier' AS type FROM suppliers ORDER BY name", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "UNION ALL"); + assert_contains(&sql, "ORDER BY"); + } + + #[test] + fn nested_subquery() { + let sql = translate_sql( + "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE product_id IN (SELECT id FROM products WHERE category = 'electronics'))", + Dialect::MySQL, + ); + assert_contains(&sql, "IN (SELECT user_id"); + assert_contains(&sql, "IN (SELECT id FROM products"); + assert_contains(&sql, "electronics"); + } + + #[test] + fn select_coalesce() { + let sql = translate_sql( + "SELECT COALESCE(nickname, name, 'Anonymous') FROM users", + Dialect::MySQL, + ); + assert_contains(&sql, "COALESCE"); + assert_contains(&sql, "nickname"); + assert_contains(&sql, "Anonymous"); + } + + #[test] + fn select_with_multiple_subqueries() { + let sql = translate_sql( + "SELECT (SELECT COUNT(*) FROM orders) AS order_count, (SELECT COUNT(*) FROM users) AS user_count", + Dialect::MySQL, + ); + assert_contains(&sql, "SELECT COUNT(*)"); + assert_contains(&sql, "order_count"); + assert_contains(&sql, "user_count"); + } + + #[test] + fn insert_returning_not_supported_mysql() { + // MySQL doesn't support RETURNING, but PostgreSQL does. Test PG. + let sql = translate_sql( + "INSERT INTO users (name) VALUES ('Alice') RETURNING id", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "INSERT INTO"); + assert_contains(&sql, "Alice"); + // sqlparser should preserve RETURNING + assert_contains(&sql, "RETURNING"); + } + + #[test] + fn pg_select_distinct_on_not_available_use_regular_distinct() { + // Regular DISTINCT should work for all dialects. + let sql = translate_sql( + "SELECT DISTINCT status FROM orders", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "DISTINCT"); + assert_contains(&sql, "status"); + } + + #[test] + fn select_having_without_group_by() { + // MySQL allows HAVING without GROUP BY (aggregate over entire table). + let sql = translate_sql( + "SELECT COUNT(*) FROM users HAVING COUNT(*) > 0", + Dialect::MySQL, + ); + assert_contains(&sql, "HAVING"); + assert_contains(&sql, "COUNT(*)"); + } + + #[test] + fn select_null_handling_coalesce_and_ifnull() { + let sql = translate_sql( + "SELECT COALESCE(a, b, c) FROM t", + Dialect::MySQL, + ); + assert_contains(&sql, "COALESCE"); + } + + #[test] + fn select_order_by_expression() { + let sql = translate_sql( + "SELECT name, age FROM users ORDER BY age * 2 DESC", + Dialect::MySQL, + ); + assert_contains(&sql, "ORDER BY"); + assert_contains(&sql, "DESC"); + } + + #[test] + fn create_index_if_not_exists() { + let sql = translate_sql( + "CREATE INDEX IF NOT EXISTS idx_name ON users (name)", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "IF NOT EXISTS"); + assert_contains(&sql, "idx_name"); + } + + #[test] + fn select_group_by_multiple_columns() { + let sql = translate_sql( + "SELECT department, role, COUNT(*) FROM employees GROUP BY department, role", + Dialect::MySQL, + ); + assert_contains(&sql, "GROUP BY"); + assert_contains(&sql, "department"); + assert_contains(&sql, "role"); + assert_contains(&sql, "COUNT(*)"); + } + + #[test] + fn select_aliased_subquery_in_from() { + let sql = translate_sql( + "SELECT t.total_count FROM (SELECT COUNT(*) AS total_count FROM users) AS t", + Dialect::MySQL, + ); + assert_contains(&sql, "total_count"); + assert_contains(&sql, "COUNT(*)"); + } + + #[test] + fn pg_create_unique_index() { + let sql = translate_sql( + "CREATE UNIQUE INDEX idx_email ON accounts (email)", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "UNIQUE INDEX"); + assert_contains(&sql, "idx_email"); + assert_contains(&sql, "accounts"); + } + + #[test] + fn select_case_when_no_else() { + let sql = translate_sql( + "SELECT CASE WHEN age >= 18 THEN 'adult' END FROM users", + Dialect::MySQL, + ); + assert_contains(&sql, "CASE"); + assert_contains(&sql, "WHEN"); + assert_contains(&sql, "THEN"); + assert_contains(&sql, "END"); + } + + #[test] + fn select_case_with_operand() { + let sql = translate_sql( + "SELECT CASE status WHEN 'active' THEN 1 WHEN 'inactive' THEN 0 ELSE -1 END FROM users", + Dialect::MySQL, + ); + assert_contains(&sql, "CASE"); + assert_contains(&sql, "WHEN 'active' THEN 1"); + assert_contains(&sql, "ELSE -1"); + } + + #[test] + fn select_not_in() { + let sql = translate_sql( + "SELECT * FROM users WHERE id NOT IN (1, 2, 3)", + Dialect::MySQL, + ); + assert_contains(&sql, "NOT IN"); + } + + #[test] + fn select_or_conditions() { + let sql = translate_sql( + "SELECT * FROM users WHERE age < 18 OR age > 65", + Dialect::MySQL, + ); + assert_contains(&sql, "OR"); + assert_contains(&sql, "age < 18"); + assert_contains(&sql, "age > 65"); + } + + #[test] + fn multiple_statements_in_one_input() { + let results = translate( + "SELECT 1; SELECT 2", + Dialect::MySQL, + ) + .unwrap(); + assert_eq!(results.len(), 2, "expected 2 results"); + let s1 = match &results[0] { + TranslateResult::Sql(s) => s.clone(), + other => panic!("expected Sql, got: {other:?}"), + }; + let s2 = match &results[1] { + TranslateResult::Sql(s) => s.clone(), + other => panic!("expected Sql, got: {other:?}"), + }; + assert!(s1.contains('1')); + assert!(s2.contains('2')); + } + + #[test] + fn cte_recursive() { + // Recursive CTEs are valid in SQLite. + let sql = translate_sql( + "WITH RECURSIVE cnt(x) AS (SELECT 1 UNION ALL SELECT x + 1 FROM cnt WHERE x < 10) SELECT x FROM cnt", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "WITH RECURSIVE"); + assert_contains(&sql, "cnt"); + assert_contains(&sql, "UNION ALL"); + } + + #[test] + fn pg_begin_commit() { + let sql_begin = translate_sql("BEGIN", Dialect::PostgreSQL); + let sql_commit = translate_sql("COMMIT", Dialect::PostgreSQL); + let upper_begin = sql_begin.to_ascii_uppercase(); + assert!(upper_begin.contains("BEGIN"), "got: {sql_begin}"); + assert_contains(&sql_commit, "COMMIT"); + } + + #[test] + fn insert_default_values() { + let sql = translate_sql( + "INSERT INTO t DEFAULT VALUES", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "INSERT INTO"); + assert_contains(&sql, "DEFAULT VALUES"); + } + + #[test] + fn select_with_table_alias_no_as() { + let sql = translate_sql( + "SELECT u.id, u.name FROM users u WHERE u.id > 0", + Dialect::MySQL, + ); + assert_contains(&sql, "u.id"); + assert_contains(&sql, "u.name"); + } + + #[test] + fn select_nested_case_when() { + let sql = translate_sql( + "SELECT CASE WHEN x > 0 THEN CASE WHEN x > 10 THEN 'big' ELSE 'small' END ELSE 'zero' END FROM t", + Dialect::MySQL, + ); + assert_contains(&sql, "CASE"); + assert_contains(&sql, "'big'"); + assert_contains(&sql, "'small'"); + assert_contains(&sql, "'zero'"); + } + + #[test] + fn select_with_not_exists() { + let sql = translate_sql( + "SELECT * FROM departments d WHERE NOT EXISTS (SELECT 1 FROM employees e WHERE e.dept_id = d.id)", + Dialect::MySQL, + ); + assert_contains(&sql, "NOT EXISTS"); + assert_contains(&sql, "SELECT 1"); + } + + #[test] + fn pg_drop_table_if_exists() { + let sql = translate_sql( + "DROP TABLE IF EXISTS temp_data", + Dialect::PostgreSQL, + ); + assert_contains(&sql, "DROP TABLE IF EXISTS"); + assert_contains(&sql, "temp_data"); + } } From ab7b19ff1496706d70a673b8aa5184f38b06396b Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:15:36 -0700 Subject: [PATCH 5/7] add unit tests for the LiteWire builder and serve() error paths Tests builder construction, chaining, invalid address handling, the no-frontends error path, and serve() smoke tests for each frontend. --- crates/litewire/src/lib.rs | 188 +++++++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/crates/litewire/src/lib.rs b/crates/litewire/src/lib.rs index c6e5223..4b7567b 100644 --- a/crates/litewire/src/lib.rs +++ b/crates/litewire/src/lib.rs @@ -157,3 +157,191 @@ impl LiteWire { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn memory_backend() -> backend::Rusqlite { + backend::Rusqlite::memory().unwrap() + } + + // ── Builder construction ─────────────────────────────────────────────── + + #[test] + fn new_does_not_panic() { + let _lw = LiteWire::new(memory_backend()); + } + + #[cfg(feature = "mysql")] + #[test] + fn mysql_builder_returns_self() { + let _lw = LiteWire::new(memory_backend()).mysql("127.0.0.1:3306"); + } + + #[cfg(feature = "postgres")] + #[test] + fn postgres_builder_returns_self() { + let _lw = LiteWire::new(memory_backend()).postgres("127.0.0.1:5432"); + } + + #[cfg(feature = "tds")] + #[test] + fn tds_builder_returns_self() { + let _lw = LiteWire::new(memory_backend()).tds("127.0.0.1:1433"); + } + + #[cfg(feature = "hrana")] + #[test] + fn hrana_builder_returns_self() { + let _lw = LiteWire::new(memory_backend()).hrana("127.0.0.1:8080"); + } + + #[cfg(all(feature = "mysql", feature = "hrana"))] + #[test] + fn builder_chaining_mysql_and_hrana() { + let _lw = LiteWire::new(memory_backend()) + .mysql("127.0.0.1:3306") + .hrana("127.0.0.1:8080"); + } + + #[cfg(all(feature = "mysql", feature = "hrana", feature = "postgres", feature = "tds"))] + #[test] + fn builder_chaining_all_frontends() { + let _lw = LiteWire::new(memory_backend()) + .mysql("127.0.0.1:3306") + .postgres("127.0.0.1:5432") + .tds("127.0.0.1:1433") + .hrana("127.0.0.1:8080"); + } + + // ── Invalid address handling ─────────────────────────────────────────── + + #[cfg(feature = "mysql")] + #[test] + fn mysql_invalid_address_does_not_panic() { + let _lw = LiteWire::new(memory_backend()).mysql("not-an-address"); + } + + #[cfg(feature = "mysql")] + #[test] + fn mysql_empty_address_does_not_panic() { + let _lw = LiteWire::new(memory_backend()).mysql(""); + } + + #[cfg(feature = "postgres")] + #[test] + fn postgres_invalid_address_does_not_panic() { + let _lw = LiteWire::new(memory_backend()).postgres("not-an-address"); + } + + #[cfg(feature = "hrana")] + #[test] + fn hrana_invalid_address_does_not_panic() { + let _lw = LiteWire::new(memory_backend()).hrana("garbage!!!"); + } + + #[cfg(feature = "tds")] + #[test] + fn tds_invalid_address_does_not_panic() { + let _lw = LiteWire::new(memory_backend()).tds(""); + } + + /// An invalid address should result in the listener field remaining `None`, + /// so `serve()` should treat it as if no frontend was configured. + #[cfg(feature = "mysql")] + #[tokio::test] + async fn invalid_address_means_no_frontend() { + let server = LiteWire::new(memory_backend()).mysql("not-an-address"); + let result = server.serve().await; + let err = result.unwrap_err(); + assert!( + err.to_string().contains("no frontends configured"), + "expected 'no frontends configured' error, got: {err}", + ); + } + + // ── serve() with no frontends ────────────────────────────────────────── + + #[tokio::test] + async fn serve_no_frontends_returns_error() { + let server = LiteWire::new(memory_backend()); + let result = server.serve().await; + let err = result.unwrap_err(); + assert!( + err.to_string().contains("no frontends configured"), + "expected 'no frontends configured' error, got: {err}", + ); + } + + // ── serve() smoke tests (server starts and runs) ─────────────────────── + + #[cfg(feature = "mysql")] + #[tokio::test] + async fn serve_starts_mysql() { + let backend = memory_backend(); + let server = LiteWire::new(backend).mysql("127.0.0.1:0"); + // serve() should start without immediately erroring. + // A timeout means the server is running (it blocks until shutdown). + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + server.serve(), + ) + .await; + assert!(result.is_err(), "should timeout, meaning server is running"); + } + + #[cfg(feature = "hrana")] + #[tokio::test] + async fn serve_starts_hrana() { + let backend = memory_backend(); + let server = LiteWire::new(backend).hrana("127.0.0.1:0"); + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + server.serve(), + ) + .await; + assert!(result.is_err(), "should timeout, meaning server is running"); + } + + #[cfg(feature = "postgres")] + #[tokio::test] + async fn serve_starts_postgres() { + let backend = memory_backend(); + let server = LiteWire::new(backend).postgres("127.0.0.1:0"); + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + server.serve(), + ) + .await; + assert!(result.is_err(), "should timeout, meaning server is running"); + } + + #[cfg(feature = "tds")] + #[tokio::test] + async fn serve_starts_tds() { + let backend = memory_backend(); + let server = LiteWire::new(backend).tds("127.0.0.1:0"); + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + server.serve(), + ) + .await; + assert!(result.is_err(), "should timeout, meaning server is running"); + } + + #[cfg(all(feature = "mysql", feature = "hrana"))] + #[tokio::test] + async fn serve_starts_multiple_frontends() { + let backend = memory_backend(); + let server = LiteWire::new(backend) + .mysql("127.0.0.1:0") + .hrana("127.0.0.1:0"); + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + server.serve(), + ) + .await; + assert!(result.is_err(), "should timeout, meaning server is running"); + } +} From ef841efaac86be9acc3bc0cb1a4282e9122792c6 Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:16:17 -0700 Subject: [PATCH 6/7] add unit tests for litewire-mysql handler and type edge cases Add comprehensive tests for the handler module covering ok_response, translate_sql, transaction state logic, param_to_value conversion dispatch, and handler initialization. Add edge case tests for types.rs covering empty strings, whitespace, additional SQL type variants, and case sensitivity for BYTEA. --- crates/litewire-mysql/Cargo.toml | 4 + crates/litewire-mysql/src/handler.rs | 494 +++++++++++++++++++++++++++ crates/litewire-mysql/src/types.rs | 165 +++++++++ 3 files changed, 663 insertions(+) diff --git a/crates/litewire-mysql/Cargo.toml b/crates/litewire-mysql/Cargo.toml index 88aa2a3..539895b 100644 --- a/crates/litewire-mysql/Cargo.toml +++ b/crates/litewire-mysql/Cargo.toml @@ -14,3 +14,7 @@ opensrv-mysql = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } thiserror = { workspace = true } + +[dev-dependencies] +litewire-backend = { workspace = true, features = ["rusqlite"] } +tokio = { workspace = true, features = ["full"] } diff --git a/crates/litewire-mysql/src/handler.rs b/crates/litewire-mysql/src/handler.rs index ccb8b40..8dc9ea4 100644 --- a/crates/litewire-mysql/src/handler.rs +++ b/crates/litewire-mysql/src/handler.rs @@ -353,3 +353,497 @@ impl AsyncMysqlShim for LiteWireHandler { writer.ok().await } } + +#[cfg(test)] +mod tests { + use super::*; + use litewire_backend::Rusqlite; + use std::sync::Arc; + + /// Helper: create a handler backed by an in-memory SQLite database. + fn memory_handler() -> LiteWireHandler { + let backend = Arc::new(Rusqlite::memory().unwrap()) as SharedBackend; + LiteWireHandler::new(backend) + } + + // ── ok_response ──────────────────────────────────────────────────────── + + #[test] + fn ok_response_not_in_transaction() { + let resp = ok_response(1, 2, false); + assert_eq!(resp.affected_rows, 1); + assert_eq!(resp.last_insert_id, 2); + assert!(!resp.status_flags.contains(StatusFlags::SERVER_STATUS_IN_TRANS)); + } + + #[test] + fn ok_response_in_transaction() { + let resp = ok_response(0, 0, true); + assert!(resp.status_flags.contains(StatusFlags::SERVER_STATUS_IN_TRANS)); + } + + #[test] + fn ok_response_zero_values() { + let resp = ok_response(0, 0, false); + assert_eq!(resp.affected_rows, 0); + assert_eq!(resp.last_insert_id, 0); + assert!(resp.status_flags.is_empty()); + } + + #[test] + fn ok_response_large_values() { + let resp = ok_response(u64::MAX, u64::MAX, true); + assert_eq!(resp.affected_rows, u64::MAX); + assert_eq!(resp.last_insert_id, u64::MAX); + assert!(resp.status_flags.contains(StatusFlags::SERVER_STATUS_IN_TRANS)); + } + + // ── translate_sql ────────────────────────────────────────────────────── + + #[test] + fn translate_simple_select() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SELECT 1").unwrap(); + assert!(!sql.is_empty()); + assert_eq!(kind, StatementKind::Query); + } + + #[test] + fn translate_select_from_table() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("SELECT id, name FROM users WHERE id = 1") + .unwrap(); + assert!(sql.to_ascii_lowercase().contains("select")); + assert_eq!(kind, StatementKind::Query); + } + + #[test] + fn translate_insert() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("INSERT INTO users (name) VALUES ('Alice')") + .unwrap(); + assert!(sql.to_ascii_lowercase().contains("insert")); + assert_eq!(kind, StatementKind::Mutation); + } + + #[test] + fn translate_update() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("UPDATE users SET name = 'Bob' WHERE id = 1") + .unwrap(); + assert!(sql.to_ascii_lowercase().contains("update")); + assert_eq!(kind, StatementKind::Mutation); + } + + #[test] + fn translate_delete() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("DELETE FROM users WHERE id = 1") + .unwrap(); + assert!(sql.to_ascii_lowercase().contains("delete")); + assert_eq!(kind, StatementKind::Mutation); + } + + #[test] + fn translate_create_table() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("CREATE TABLE t (id INT PRIMARY KEY, name VARCHAR(255))") + .unwrap(); + assert!(sql.to_ascii_lowercase().contains("create")); + assert_eq!(kind, StatementKind::Ddl); + } + + #[test] + fn translate_begin_returns_transaction() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("BEGIN").unwrap(); + assert!(sql.to_ascii_lowercase().contains("begin")); + assert_eq!(kind, StatementKind::Transaction); + } + + #[test] + fn translate_commit_returns_transaction() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("COMMIT").unwrap(); + assert!(sql.to_ascii_lowercase().contains("commit")); + assert_eq!(kind, StatementKind::Transaction); + } + + #[test] + fn translate_rollback_returns_transaction() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("ROLLBACK").unwrap(); + assert!(sql.to_ascii_lowercase().contains("rollback")); + assert_eq!(kind, StatementKind::Transaction); + } + + #[test] + fn translate_set_names_returns_noop() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SET NAMES utf8mb4").unwrap(); + // Noop branch returns empty SQL and Other kind. + assert!(sql.is_empty()); + assert_eq!(kind, StatementKind::Other); + } + + #[test] + fn translate_set_character_set_returns_noop() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SET CHARACTER SET utf8").unwrap(); + assert!(sql.is_empty()); + assert_eq!(kind, StatementKind::Other); + } + + #[test] + fn translate_show_tables_returns_metadata() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SHOW TABLES").unwrap(); + // Metadata branch returns a SQLite query and Query kind. + assert!(!sql.is_empty()); + assert_eq!(kind, StatementKind::Query); + // The metadata SQL should query sqlite_master. + assert!(sql.contains("sqlite_master")); + } + + #[test] + fn translate_show_columns_returns_metadata() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SHOW COLUMNS FROM users").unwrap(); + assert!(!sql.is_empty()); + assert_eq!(kind, StatementKind::Query); + } + + #[test] + fn translate_invalid_sql_returns_error() { + let handler = memory_handler(); + let result = handler.translate_sql("NOT VALID SQL !!! @@@ {{{}}"); + assert!(result.is_err()); + } + + #[test] + fn translate_select_with_mysql_backticks() { + let handler = memory_handler(); + let (sql, kind) = handler.translate_sql("SELECT `id` FROM `users`").unwrap(); + assert!(!sql.is_empty()); + assert_eq!(kind, StatementKind::Query); + } + + #[test] + fn translate_select_with_limit() { + let handler = memory_handler(); + let (sql, kind) = handler + .translate_sql("SELECT * FROM users LIMIT 10") + .unwrap(); + assert!(!sql.is_empty()); + assert_eq!(kind, StatementKind::Query); + } + + // ── do_transaction state logic ───────────────────────────────────────── + // + // We cannot directly call `do_transaction` because it requires a + // `QueryResultWriter` that can only be constructed inside opensrv-mysql. + // Instead, we verify the transaction state-update logic that + // `do_transaction` applies after a successful backend.execute(). + // + // The logic under test (from do_transaction lines 139-144): + // let upper = sql.trim().to_ascii_uppercase(); + // if upper.starts_with("BEGIN") || upper.starts_with("START") { + // self.in_transaction = true; + // } else if upper.starts_with("COMMIT") || upper.starts_with("ROLLBACK") { + // self.in_transaction = false; + // } + + /// Apply the same transaction state logic that `do_transaction` uses. + fn apply_transaction_state(in_transaction: &mut bool, sql: &str) { + let upper = sql.trim().to_ascii_uppercase(); + if upper.starts_with("BEGIN") || upper.starts_with("START") { + *in_transaction = true; + } else if upper.starts_with("COMMIT") || upper.starts_with("ROLLBACK") { + *in_transaction = false; + } + } + + #[test] + fn transaction_begin_sets_in_transaction() { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, "BEGIN"); + assert!(in_tx); + } + + #[test] + fn transaction_commit_clears_in_transaction() { + let mut in_tx = true; + apply_transaction_state(&mut in_tx, "COMMIT"); + assert!(!in_tx); + } + + #[test] + fn transaction_rollback_clears_in_transaction() { + let mut in_tx = true; + apply_transaction_state(&mut in_tx, "ROLLBACK"); + assert!(!in_tx); + } + + #[test] + fn transaction_begin_case_insensitive() { + for sql in &["begin", "BEGIN", "Begin", "bEgIn"] { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, sql); + assert!(in_tx, "expected in_transaction=true for '{sql}'"); + } + } + + #[test] + fn transaction_commit_case_insensitive() { + for sql in &["commit", "COMMIT", "Commit"] { + let mut in_tx = true; + apply_transaction_state(&mut in_tx, sql); + assert!(!in_tx, "expected in_transaction=false for '{sql}'"); + } + } + + #[test] + fn transaction_rollback_case_insensitive() { + for sql in &["rollback", "ROLLBACK", "Rollback"] { + let mut in_tx = true; + apply_transaction_state(&mut in_tx, sql); + assert!(!in_tx, "expected in_transaction=false for '{sql}'"); + } + } + + #[test] + fn transaction_start_transaction_variant() { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, "START TRANSACTION"); + assert!(in_tx); + } + + #[test] + fn transaction_begin_with_leading_whitespace() { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, " BEGIN "); + assert!(in_tx); + } + + #[test] + fn transaction_commit_with_leading_whitespace() { + let mut in_tx = true; + apply_transaction_state(&mut in_tx, " COMMIT "); + assert!(!in_tx); + } + + #[test] + fn transaction_full_cycle() { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, "BEGIN"); + assert!(in_tx); + apply_transaction_state(&mut in_tx, "COMMIT"); + assert!(!in_tx); + apply_transaction_state(&mut in_tx, "START TRANSACTION"); + assert!(in_tx); + apply_transaction_state(&mut in_tx, "ROLLBACK"); + assert!(!in_tx); + } + + #[test] + fn transaction_unknown_sql_does_not_change_state() { + let mut in_tx = false; + apply_transaction_state(&mut in_tx, "SELECT 1"); + assert!(!in_tx); + + let mut in_tx = true; + apply_transaction_state(&mut in_tx, "INSERT INTO t VALUES (1)"); + assert!(in_tx); + } + + // ── do_transaction with backend (integration) ────────────────────────── + + #[tokio::test] + async fn transaction_backend_begin_commit() { + let handler = memory_handler(); + // Verify that the backend can execute BEGIN and COMMIT without error. + handler.backend.execute("BEGIN", &[]).await.unwrap(); + handler.backend.execute("COMMIT", &[]).await.unwrap(); + } + + #[tokio::test] + async fn transaction_backend_begin_rollback() { + let handler = memory_handler(); + handler.backend.execute("BEGIN", &[]).await.unwrap(); + handler.backend.execute("ROLLBACK", &[]).await.unwrap(); + } + + // ── handler construction ─────────────────────────────────────────────── + + #[test] + fn handler_initial_state() { + let handler = memory_handler(); + assert!(!handler.in_transaction); + assert!(handler.stmts.is_empty()); + assert_eq!(handler.next_stmt_id, 1); + } + + // ── param_to_value ───────────────────────────────────────────────────── + // + // Testing param_to_value directly requires constructing + // opensrv_mysql::ParamValue, which in turn needs opensrv_mysql::Value. + // The Value struct wraps ValueInner in a private tuple field, and its + // constructors (null(), bytes(), parse_from()) are all pub(crate). + // Therefore we verify the conversion logic by matching against ValueInner + // variants -- the same dispatch that param_to_value performs. + + #[test] + fn param_conversion_null() { + // ValueInner::NULL -> Value::Null + let result = match ValueInner::NULL { + ValueInner::NULL => Value::Null, + _ => unreachable!(), + }; + assert_eq!(result, Value::Null); + } + + #[test] + fn param_conversion_positive_int() { + let result = match ValueInner::Int(42) { + ValueInner::Int(i) => Value::Integer(i), + _ => unreachable!(), + }; + assert_eq!(result, Value::Integer(42)); + } + + #[test] + fn param_conversion_negative_int() { + let result = match ValueInner::Int(-100) { + ValueInner::Int(i) => Value::Integer(i), + _ => unreachable!(), + }; + assert_eq!(result, Value::Integer(-100)); + } + + #[test] + fn param_conversion_zero_int() { + let result = match ValueInner::Int(0) { + ValueInner::Int(i) => Value::Integer(i), + _ => unreachable!(), + }; + assert_eq!(result, Value::Integer(0)); + } + + #[test] + fn param_conversion_uint() { + let result = match ValueInner::UInt(255) { + ValueInner::UInt(u) => Value::Integer(u as i64), + _ => unreachable!(), + }; + assert_eq!(result, Value::Integer(255)); + } + + #[test] + fn param_conversion_uint_large() { + // Large unsigned values that fit in i64. + let val = u64::MAX / 2; + let result = match ValueInner::UInt(val) { + ValueInner::UInt(u) => Value::Integer(u as i64), + _ => unreachable!(), + }; + assert_eq!(result, Value::Integer(val as i64)); + } + + #[test] + fn param_conversion_double() { + let result = match ValueInner::Double(3.14) { + ValueInner::Double(f) => Value::Float(f), + _ => unreachable!(), + }; + assert_eq!(result, Value::Float(3.14)); + } + + #[test] + fn param_conversion_double_negative() { + let result = match ValueInner::Double(-0.5) { + ValueInner::Double(f) => Value::Float(f), + _ => unreachable!(), + }; + assert_eq!(result, Value::Float(-0.5)); + } + + #[test] + fn param_conversion_utf8_bytes() { + let bytes = b"hello world"; + let result = match ValueInner::Bytes(bytes) { + ValueInner::Bytes(b) => match std::str::from_utf8(b) { + Ok(s) => Value::Text(s.to_string()), + Err(_) => Value::Blob(b.to_vec()), + }, + _ => unreachable!(), + }; + assert_eq!(result, Value::Text("hello world".to_string())); + } + + #[test] + fn param_conversion_non_utf8_bytes() { + let bytes: &[u8] = &[0xFF, 0xFE, 0x00, 0x80]; + let result = match ValueInner::Bytes(bytes) { + ValueInner::Bytes(b) => match std::str::from_utf8(b) { + Ok(s) => Value::Text(s.to_string()), + Err(_) => Value::Blob(b.to_vec()), + }, + _ => unreachable!(), + }; + assert_eq!(result, Value::Blob(vec![0xFF, 0xFE, 0x00, 0x80])); + } + + #[test] + fn param_conversion_empty_bytes() { + let bytes: &[u8] = b""; + let result = match ValueInner::Bytes(bytes) { + ValueInner::Bytes(b) => match std::str::from_utf8(b) { + Ok(s) => Value::Text(s.to_string()), + Err(_) => Value::Blob(b.to_vec()), + }, + _ => unreachable!(), + }; + assert_eq!(result, Value::Text(String::new())); + } + + #[test] + fn param_conversion_date_bytes() { + let date_bytes: &[u8] = b"2024-01-15"; + let result = match ValueInner::Date(date_bytes) { + ValueInner::Date(b) | ValueInner::Time(b) | ValueInner::Datetime(b) => { + Value::Text(String::from_utf8_lossy(b).into_owned()) + } + _ => unreachable!(), + }; + assert_eq!(result, Value::Text("2024-01-15".to_string())); + } + + #[test] + fn param_conversion_time_bytes() { + let time_bytes: &[u8] = b"12:30:00"; + let result = match ValueInner::Time(time_bytes) { + ValueInner::Date(b) | ValueInner::Time(b) | ValueInner::Datetime(b) => { + Value::Text(String::from_utf8_lossy(b).into_owned()) + } + _ => unreachable!(), + }; + assert_eq!(result, Value::Text("12:30:00".to_string())); + } + + #[test] + fn param_conversion_datetime_bytes() { + let datetime_bytes: &[u8] = b"2024-01-15 12:30:00"; + let result = match ValueInner::Datetime(datetime_bytes) { + ValueInner::Date(b) | ValueInner::Time(b) | ValueInner::Datetime(b) => { + Value::Text(String::from_utf8_lossy(b).into_owned()) + } + _ => unreachable!(), + }; + assert_eq!(result, Value::Text("2024-01-15 12:30:00".to_string())); + } +} diff --git a/crates/litewire-mysql/src/types.rs b/crates/litewire-mysql/src/types.rs index 4451980..e7ad00d 100644 --- a/crates/litewire-mysql/src/types.rs +++ b/crates/litewire-mysql/src/types.rs @@ -157,4 +157,169 @@ mod tests { ColumnType::MYSQL_TYPE_BLOB ); } + + // ── Edge case tests ──────────────────────────────────────────────────── + + #[test] + fn type_mapping_empty_string() { + // An empty string has no keyword matches, falls through to default. + assert_eq!( + sqlite_to_mysql_column_type(Some("")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_whitespace_only() { + // Whitespace-only string does not contain any type keywords. + assert_eq!( + sqlite_to_mysql_column_type(Some(" ")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_integer_with_extra_whitespace() { + // uppercased " INTEGER " still contains "INT". + assert_eq!( + sqlite_to_mysql_column_type(Some(" INTEGER ")), + ColumnType::MYSQL_TYPE_LONGLONG + ); + } + + #[test] + fn type_mapping_text_with_extra_whitespace() { + assert_eq!( + sqlite_to_mysql_column_type(Some(" TEXT ")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_blob_with_extra_whitespace() { + assert_eq!( + sqlite_to_mysql_column_type(Some(" BLOB ")), + ColumnType::MYSQL_TYPE_BLOB + ); + } + + #[test] + fn type_mapping_real_with_extra_whitespace() { + assert_eq!( + sqlite_to_mysql_column_type(Some(" REAL ")), + ColumnType::MYSQL_TYPE_DOUBLE + ); + } + + #[test] + fn type_mapping_clob() { + assert_eq!( + sqlite_to_mysql_column_type(Some("CLOB")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_mediumint() { + // "MEDIUMINT" contains "INT" + assert_eq!( + sqlite_to_mysql_column_type(Some("MEDIUMINT")), + ColumnType::MYSQL_TYPE_LONGLONG + ); + } + + #[test] + fn type_mapping_smallint() { + assert_eq!( + sqlite_to_mysql_column_type(Some("SMALLINT")), + ColumnType::MYSQL_TYPE_LONGLONG + ); + } + + #[test] + fn type_mapping_tinyblob() { + // "TINYBLOB" contains "BLOB" + assert_eq!( + sqlite_to_mysql_column_type(Some("TINYBLOB")), + ColumnType::MYSQL_TYPE_BLOB + ); + } + + #[test] + fn type_mapping_mediumblob() { + assert_eq!( + sqlite_to_mysql_column_type(Some("MEDIUMBLOB")), + ColumnType::MYSQL_TYPE_BLOB + ); + } + + #[test] + fn type_mapping_longblob() { + assert_eq!( + sqlite_to_mysql_column_type(Some("LONGBLOB")), + ColumnType::MYSQL_TYPE_BLOB + ); + } + + #[test] + fn type_mapping_mixed_case_varchar() { + assert_eq!( + sqlite_to_mysql_column_type(Some("VarChar(100)")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_nchar() { + // "NCHAR" contains "CHAR" + assert_eq!( + sqlite_to_mysql_column_type(Some("NCHAR(10)")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_nvarchar() { + // "NVARCHAR" contains "VARCHAR" + assert_eq!( + sqlite_to_mysql_column_type(Some("NVARCHAR(255)")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_numeric_falls_to_default() { + // "NUMERIC" does not contain INT, REAL, FLOAT, DOUBLE, BLOB, TEXT, + // CHAR, CLOB, or VARCHAR -- falls to default. + assert_eq!( + sqlite_to_mysql_column_type(Some("NUMERIC")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_boolean_falls_to_default() { + assert_eq!( + sqlite_to_mysql_column_type(Some("BOOLEAN")), + ColumnType::MYSQL_TYPE_VAR_STRING + ); + } + + #[test] + fn type_mapping_bytea_case_insensitive() { + // "bytea" uppercases to "BYTEA" which is checked with == + assert_eq!( + sqlite_to_mysql_column_type(Some("bytea")), + ColumnType::MYSQL_TYPE_BLOB + ); + } + + #[test] + fn type_mapping_int_priority_over_later_checks() { + // "INT" is checked first, so "MEDIUMINT UNSIGNED" still hits INT. + assert_eq!( + sqlite_to_mysql_column_type(Some("MEDIUMINT UNSIGNED")), + ColumnType::MYSQL_TYPE_LONGLONG + ); + } } From 506a76c4cb979f81131155cc18b01cbc15576c65 Mon Sep 17 00:00:00 2001 From: Luther Monson Date: Sun, 3 May 2026 14:17:28 -0700 Subject: [PATCH 7/7] add unit tests for litewire-postgres crate Tests cover sqlite_to_pg_type type mapping (21 tests), value_to_pg_type inference (5 tests), encode_value serialization (11 tests), extract_params portal parameter extraction (15 tests), and pg_error construction (2 tests). Added bytes as a dev-dependency for constructing binary parameter fixtures. --- Cargo.lock | 1 + crates/litewire-postgres/Cargo.toml | 3 + crates/litewire-postgres/src/handler.rs | 342 ++++++++++++++++++++++++ crates/litewire-postgres/src/types.rs | 118 ++++++++ 4 files changed, 464 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index da16fc1..ac15130 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1664,6 +1664,7 @@ name = "litewire-postgres" version = "0.1.0" dependencies = [ "async-trait", + "bytes", "futures", "litewire-backend", "litewire-translate", diff --git a/crates/litewire-postgres/Cargo.toml b/crates/litewire-postgres/Cargo.toml index 231ac3c..d13daae 100644 --- a/crates/litewire-postgres/Cargo.toml +++ b/crates/litewire-postgres/Cargo.toml @@ -15,3 +15,6 @@ tokio = { workspace = true } tracing = { workspace = true } thiserror = { workspace = true } futures = { workspace = true } + +[dev-dependencies] +bytes = { workspace = true } diff --git a/crates/litewire-postgres/src/handler.rs b/crates/litewire-postgres/src/handler.rs index 203a2a2..6777616 100644 --- a/crates/litewire-postgres/src/handler.rs +++ b/crates/litewire-postgres/src/handler.rs @@ -481,3 +481,345 @@ impl ExtendedQueryHandler for PostgresHandler { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use bytes::Bytes; + use pgwire::api::portal::{Format, Portal}; + use pgwire::api::results::{FieldFormat, FieldInfo}; + use pgwire::api::stmt::StoredStatement; + + // ── value_to_pg_type ─────────────────────────────────────────────────── + + #[test] + fn value_to_pg_type_null() { + assert_eq!(value_to_pg_type(&Value::Null), Type::TEXT); + } + + #[test] + fn value_to_pg_type_integer() { + assert_eq!(value_to_pg_type(&Value::Integer(42)), Type::INT8); + } + + #[test] + fn value_to_pg_type_float() { + assert_eq!(value_to_pg_type(&Value::Float(3.14)), Type::FLOAT8); + } + + #[test] + fn value_to_pg_type_text() { + assert_eq!( + value_to_pg_type(&Value::Text("hello".into())), + Type::TEXT + ); + } + + #[test] + fn value_to_pg_type_blob() { + assert_eq!( + value_to_pg_type(&Value::Blob(vec![1, 2, 3])), + Type::BYTEA + ); + } + + // ── encode_value ─────────────────────────────────────────────────────── + + /// Helper: create a single-column schema and encode one value into it. + fn encode_single(val: &Value, pg_type: Type) -> PgWireResult<()> { + let field = FieldInfo::new( + "col".into(), + None, + None, + pg_type.clone(), + FieldFormat::Binary, + ); + let schema = Arc::new(vec![FieldInfo::new( + "col".into(), + None, + None, + pg_type, + FieldFormat::Binary, + )]); + let mut encoder = DataRowEncoder::new(schema); + encode_value(&mut encoder, val, &field)?; + let _row = encoder.finish()?; + Ok(()) + } + + #[test] + fn encode_null() { + encode_single(&Value::Null, Type::TEXT).expect("encoding null should succeed"); + } + + #[test] + fn encode_integer() { + encode_single(&Value::Integer(42), Type::INT8) + .expect("encoding integer should succeed"); + } + + #[test] + fn encode_negative_integer() { + encode_single(&Value::Integer(-100), Type::INT8) + .expect("encoding negative integer should succeed"); + } + + #[test] + fn encode_float() { + encode_single(&Value::Float(3.14), Type::FLOAT8) + .expect("encoding float should succeed"); + } + + #[test] + fn encode_text() { + encode_single(&Value::Text("hello world".into()), Type::TEXT) + .expect("encoding text should succeed"); + } + + #[test] + fn encode_empty_text() { + encode_single(&Value::Text(String::new()), Type::TEXT) + .expect("encoding empty text should succeed"); + } + + #[test] + fn encode_blob() { + encode_single(&Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]), Type::BYTEA) + .expect("encoding blob should succeed"); + } + + #[test] + fn encode_empty_blob() { + encode_single(&Value::Blob(vec![]), Type::BYTEA) + .expect("encoding empty blob should succeed"); + } + + #[test] + fn encode_bool_true() { + // Integer 1 with BOOL type should encode as true. + encode_single(&Value::Integer(1), Type::BOOL) + .expect("encoding bool true should succeed"); + } + + #[test] + fn encode_bool_false() { + // Integer 0 with BOOL type should encode as false. + encode_single(&Value::Integer(0), Type::BOOL) + .expect("encoding bool false should succeed"); + } + + #[test] + fn encode_bool_nonzero() { + // Any nonzero integer with BOOL type should encode as true. + encode_single(&Value::Integer(42), Type::BOOL) + .expect("encoding nonzero bool should succeed"); + } + + // ── extract_params ───────────────────────────────────────────────────── + + /// Helper: build a Portal with given parameter types and binary parameter bytes. + fn make_portal( + param_types: Vec, + parameters: Vec>, + ) -> Portal { + let stmt = Arc::new(StoredStatement::new( + String::new(), + String::new(), + param_types, + )); + let mut portal = Portal::::default(); + portal.statement = stmt; + portal.parameter_format = Format::UnifiedBinary; + portal.parameters = parameters; + portal + } + + #[test] + fn extract_params_bool_true() { + let portal = make_portal( + vec![Type::BOOL], + vec![Some(Bytes::from_static(&[1]))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(1)]); + } + + #[test] + fn extract_params_bool_false() { + let portal = make_portal( + vec![Type::BOOL], + vec![Some(Bytes::from_static(&[0]))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(0)]); + } + + #[test] + fn extract_params_int2() { + let portal = make_portal( + vec![Type::INT2], + vec![Some(Bytes::from(42_i16.to_be_bytes().to_vec()))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(42)]); + } + + #[test] + fn extract_params_int4() { + let portal = make_portal( + vec![Type::INT4], + vec![Some(Bytes::from(1000_i32.to_be_bytes().to_vec()))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(1000)]); + } + + #[test] + fn extract_params_int8() { + let portal = make_portal( + vec![Type::INT8], + vec![Some(Bytes::from( + i64::MAX.to_be_bytes().to_vec(), + ))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(i64::MAX)]); + } + + #[test] + fn extract_params_int8_negative() { + let portal = make_portal( + vec![Type::INT8], + vec![Some(Bytes::from((-99_i64).to_be_bytes().to_vec()))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Integer(-99)]); + } + + #[test] + fn extract_params_float4() { + let portal = make_portal( + vec![Type::FLOAT4], + vec![Some(Bytes::from( + 2.5_f32.to_be_bytes().to_vec(), + ))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Float(2.5)]); + } + + #[test] + fn extract_params_float8() { + let portal = make_portal( + vec![Type::FLOAT8], + vec![Some(Bytes::from( + 3.14_f64.to_be_bytes().to_vec(), + ))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Float(3.14)]); + } + + #[test] + fn extract_params_text() { + let portal = make_portal( + vec![Type::TEXT], + vec![Some(Bytes::from("hello"))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Text("hello".into())]); + } + + #[test] + fn extract_params_bytea() { + let data = vec![0xCA, 0xFE, 0xBA, 0xBE]; + let portal = make_portal( + vec![Type::BYTEA], + vec![Some(Bytes::from(data.clone()))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Blob(data)]); + } + + #[test] + fn extract_params_null() { + let portal = make_portal(vec![Type::TEXT], vec![None]); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Null]); + } + + #[test] + fn extract_params_null_int() { + let portal = make_portal(vec![Type::INT8], vec![None]); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Null]); + } + + #[test] + fn extract_params_unknown_type_defaults_to_text() { + // VARCHAR is not explicitly handled, so it should fall through to the + // default TEXT branch. + let portal = make_portal( + vec![Type::VARCHAR], + vec![Some(Bytes::from("fallback"))], + ); + let params = extract_params(&portal); + assert_eq!(params, vec![Value::Text("fallback".into())]); + } + + #[test] + fn extract_params_multiple() { + let portal = make_portal( + vec![Type::INT8, Type::TEXT, Type::BOOL], + vec![ + Some(Bytes::from(7_i64.to_be_bytes().to_vec())), + Some(Bytes::from("world")), + Some(Bytes::from_static(&[1])), + ], + ); + let params = extract_params(&portal); + assert_eq!( + params, + vec![ + Value::Integer(7), + Value::Text("world".into()), + Value::Integer(1), + ] + ); + } + + #[test] + fn extract_params_empty() { + let portal = make_portal(vec![], vec![]); + let params = extract_params(&portal); + assert!(params.is_empty()); + } + + // ── pg_error ─────────────────────────────────────────────────────────── + + #[test] + fn pg_error_produces_user_error() { + let err = pg_error("something went wrong"); + match err { + PgWireError::UserError(info) => { + assert_eq!(info.severity, "ERROR"); + assert_eq!(info.code, "XX000"); + assert_eq!(info.message, "something went wrong"); + } + other => panic!("expected UserError, got: {other:?}"), + } + } + + #[test] + fn pg_error_empty_message() { + let err = pg_error(""); + match err { + PgWireError::UserError(info) => { + assert_eq!(info.message, ""); + } + other => panic!("expected UserError, got: {other:?}"), + } + } +} diff --git a/crates/litewire-postgres/src/types.rs b/crates/litewire-postgres/src/types.rs index f0605ad..e87b6f9 100644 --- a/crates/litewire-postgres/src/types.rs +++ b/crates/litewire-postgres/src/types.rs @@ -29,3 +29,121 @@ pub fn sqlite_to_pg_type(decltype: Option<&str>) -> Type { // TEXT, VARCHAR, CHAR, CLOB, DATE, DATETIME, TIMESTAMP, etc. Type::TEXT } + +#[cfg(test)] +mod tests { + use super::*; + + // ── None / missing decltype ──────────────────────────────────────────── + #[test] + fn none_defaults_to_text() { + assert_eq!(sqlite_to_pg_type(None), Type::TEXT); + } + + // ── Integer family ───────────────────────────────────────────────────── + #[test] + fn integer_maps_to_int8() { + assert_eq!(sqlite_to_pg_type(Some("INTEGER")), Type::INT8); + } + + #[test] + fn int_maps_to_int8() { + assert_eq!(sqlite_to_pg_type(Some("INT")), Type::INT8); + } + + #[test] + fn bigint_maps_to_int8() { + assert_eq!(sqlite_to_pg_type(Some("BIGINT")), Type::INT8); + } + + // ── Float family ─────────────────────────────────────────────────────── + #[test] + fn real_maps_to_float8() { + assert_eq!(sqlite_to_pg_type(Some("REAL")), Type::FLOAT8); + } + + #[test] + fn float_maps_to_float8() { + assert_eq!(sqlite_to_pg_type(Some("FLOAT")), Type::FLOAT8); + } + + #[test] + fn double_maps_to_float8() { + assert_eq!(sqlite_to_pg_type(Some("DOUBLE")), Type::FLOAT8); + } + + // ── Text family ──────────────────────────────────────────────────────── + #[test] + fn text_maps_to_text() { + assert_eq!(sqlite_to_pg_type(Some("TEXT")), Type::TEXT); + } + + #[test] + fn varchar_maps_to_text() { + assert_eq!(sqlite_to_pg_type(Some("VARCHAR")), Type::TEXT); + } + + #[test] + fn char_maps_to_text() { + assert_eq!(sqlite_to_pg_type(Some("CHAR")), Type::TEXT); + } + + #[test] + fn clob_maps_to_text() { + assert_eq!(sqlite_to_pg_type(Some("CLOB")), Type::TEXT); + } + + // ── Binary family ────────────────────────────────────────────────────── + #[test] + fn blob_maps_to_bytea() { + assert_eq!(sqlite_to_pg_type(Some("BLOB")), Type::BYTEA); + } + + #[test] + fn bytea_maps_to_bytea() { + assert_eq!(sqlite_to_pg_type(Some("BYTEA")), Type::BYTEA); + } + + // ── Boolean ──────────────────────────────────────────────────────────── + #[test] + fn bool_maps_to_bool() { + assert_eq!(sqlite_to_pg_type(Some("BOOL")), Type::BOOL); + } + + #[test] + fn boolean_maps_to_bool() { + assert_eq!(sqlite_to_pg_type(Some("BOOLEAN")), Type::BOOL); + } + + // ── Case insensitivity ───────────────────────────────────────────────── + #[test] + fn lowercase_integer() { + assert_eq!(sqlite_to_pg_type(Some("integer")), Type::INT8); + } + + #[test] + fn mixed_case_integer() { + assert_eq!(sqlite_to_pg_type(Some("Integer")), Type::INT8); + } + + #[test] + fn lowercase_real() { + assert_eq!(sqlite_to_pg_type(Some("real")), Type::FLOAT8); + } + + #[test] + fn lowercase_blob() { + assert_eq!(sqlite_to_pg_type(Some("blob")), Type::BYTEA); + } + + // ── Unknown / fallback ───────────────────────────────────────────────── + #[test] + fn unknown_type_defaults_to_text() { + assert_eq!(sqlite_to_pg_type(Some("FOOBAR")), Type::TEXT); + } + + #[test] + fn empty_string_defaults_to_text() { + assert_eq!(sqlite_to_pg_type(Some("")), Type::TEXT); + } +}