diff --git a/rust/BUILD b/rust/BUILD index afa26bb6..cb7ee4c5 100644 --- a/rust/BUILD +++ b/rust/BUILD @@ -40,6 +40,7 @@ rust_test( deps = [ "@crates//:syn", "@crates//:proc-macro2", + "@crates//:rand", ], ) diff --git a/rust/value.rs b/rust/value.rs index 564d749d..98783e20 100644 --- a/rust/value.rs +++ b/rust/value.rs @@ -343,62 +343,73 @@ impl fmt::Display for StructLiteral { impl StringLiteral { pub fn unescape(&self) -> Result { - self.process_unescape(|bytes, _buf, rest| match bytes[1] { - BSP => Ok(('\x08', 2)), - TAB => Ok(('\x09', 2)), - LF_ => Ok(('\x0a', 2)), - FF_ => Ok(('\x0c', 2)), - CR_ => Ok(('\x0d', 2)), - c @ (b'"' | b'\'' | b'\\') => Ok((c as char, 2)), - b'u' => todo!("Unicode escape handling"), - _ => Err(TypeQLError::InvalidStringEscape { - full_string: rest.to_owned(), - escape: format!(r"\{}", rest.chars().nth(1).unwrap()), + self.process_unescape(|bytes| { + if bytes.len() < 2 { + return Err(1); + } + match bytes[1] { + BSP => Ok(('\x08', 2)), + TAB => Ok(('\x09', 2)), + LF_ => Ok(('\x0a', 2)), + FF_ => Ok(('\x0c', 2)), + CR_ => Ok(('\x0d', 2)), + c @ (b'"' | b'\'' | b'\\') => Ok((c as char, 2)), + b'u' => { + let escape = &bytes[2..std::cmp::min(6, bytes.len())]; + match decode_four_hex_bytes(escape) { + Some(char) => Ok((char, 6)), + None => Err(6), + } + } + _ => Err(2), } - .into()), }) } pub fn unescape_regex(&self) -> Result { - self.process_unescape(|bytes, _, _| match bytes[1] { - c @ b'"' => Ok((c as char, 2)), + self.process_unescape(|bytes| match bytes.get(1) { + Some(b'"') => Ok(('"', 2)), _ => Ok(('\\', 1)), }) } fn process_unescape(&self, escape_handler: F) -> Result where - F: Fn(&[u8], &mut String, &str) -> Result<(char, usize)>, + F: Fn(&[u8]) -> std::result::Result<(char, usize), usize>, { let bytes = self.value.as_bytes(); assert_eq!(bytes[0], bytes[bytes.len() - 1]); assert!(matches!(bytes[0], b'\'' | b'"')); let escaped_string = &self.value[1..self.value.len() - 1]; - let mut buf = String::with_capacity(escaped_string.len()); - let mut rest = escaped_string; - + let mut buf = Vec::with_capacity(escaped_string.len()); + let mut rest: &[u8] = escaped_string.as_bytes(); while !rest.is_empty() { - let (char, escaped_len) = if rest.as_bytes()[0] == b'\\' { - let bytes = rest.as_bytes(); - - if bytes.len() < 2 { - return Err(TypeQLError::InvalidStringEscape { - full_string: escaped_string.to_owned(), - escape: String::from(r"\"), + let escaped_len = if rest[0] == b'\\' { + match escape_handler(rest) { + Ok((char, escaped_len)) => { + let start = buf.len(); + buf.resize(buf.len() + char.len_utf8(), 0); + char.encode_utf8(&mut buf[start..]); + rest = &rest[escaped_len..]; + } + Err(considered_escape_seq_length) => { + let offset = escaped_string.len() - rest.len(); + let considered_escape_sequence = + escaped_string[offset..].chars().take(considered_escape_seq_length).collect(); + return Err(TypeQLError::InvalidStringEscape { + full_string: escaped_string.to_owned(), + escape: considered_escape_sequence, + } + .into()); } - .into()); } - - escape_handler(bytes, &mut buf, escaped_string)? } else { - let char = rest.chars().next().expect("string is non-empty"); - (char, char.len_utf8()) + buf.push(rest[0]); + rest = &rest[1..]; }; - buf.push(char); - rest = &rest[escaped_len..]; } - Ok(buf) + Ok(String::from_utf8(buf).expect("Expected valid utf8").to_owned()) } } @@ -407,3 +418,187 @@ const TAB: u8 = b't'; const LF_: u8 = b'n'; const FF_: u8 = b'f'; const CR_: u8 = b'r'; + +#[allow(arithmetic_overflow)] +fn decode_four_hex_bytes(bytes: &[u8]) -> Option { + if bytes.len() == 4 { + let u32_le: u32 = 0u32 + | (bytes[0] as char).to_digit(16)? << 12 + | (bytes[1] as char).to_digit(16)? << 8 + | (bytes[2] as char).to_digit(16)? << 4 + | (bytes[3] as char).to_digit(16)? << 0; + debug_assert!(char::from_u32(u32_le).is_some()); + char::from_u32(u32_le) + } else { + None + } +} + +#[cfg(test)] +pub mod tests { + use crate::{ + value::{StringLiteral, TypeQLError}, + Result, + }; + + fn parse_to_string_literal(escaped: &str) -> StringLiteral { + let crate::ValueLiteral::String(parsed) = crate::parse_value(escaped).unwrap() else { + panic!("Not parsed as string"); + }; + parsed + } + + #[test] + fn test_unescape_regex() { + { + let escaped = r#""a\"b\"c""#; + let unescaped = parse_to_string_literal(escaped).unescape_regex().unwrap(); + assert_eq!(unescaped.as_str(), r#"a"b"c"#); + } + { + let escaped = r#""abc\123""#; + let unescaped = parse_to_string_literal(escaped).unescape_regex().unwrap(); + assert_eq!(unescaped.as_str(), r#"abc\123"#); + } + // Cases that fail at parsing + { + let escaped = r#""abc\""#; + assert!(crate::parse_value(escaped).is_err()); // Parsing fails as incomplete string literal + let string_literal = StringLiteral { value: escaped.to_owned() }; + let unescaped = string_literal.unescape_regex().unwrap(); + assert_eq!(unescaped.as_str(), r#"abc\"#); + } + } + + fn assert_unescapes_to(escaped: &str, expected: &str) { + let unescaped = parse_to_string_literal(escaped).unescape().unwrap(); + assert_eq!(unescaped, expected); + } + + fn assert_unescape_errors(escaped: &str, expected_escape_sequence: &str) { + let error = parse_to_string_literal(escaped).unescape().unwrap_err(); + let TypeQLError::InvalidStringEscape { escape, .. } = &error.errors()[0] else { + panic!("Wrong error type. Was {error:?}") + }; + assert_eq!(escape, expected_escape_sequence); + } + + #[test] + fn test_unescape() { + // Succeeds + assert_unescapes_to(r#""a\tb\tc""#, "a\tb\tc"); // works + assert_unescapes_to(r#""a\"b\"c""#, r#"a"b"c"#); // works + assert_unescapes_to(r#""a\'b\'c""#, r#"a'b'c"#); // works + assert_unescapes_to(r#""a\\b\\c""#, r#"a\b\c"#); // works + // - Unicode + assert_unescapes_to(r#""abc \u0ca0\u005f\u0ca0""#, "abc ಠ_ಠ"); // works + assert_unescapes_to(r#""abc \u0CA0\u005F\u0CA0""#, "abc ಠ_ಠ"); // caps + assert_unescapes_to(r#""abc \u0CA01234""#, "abc ಠ1234"); // consumes only 4 + + // Errors + assert_unescape_errors(r#""ab\c""#, r"\c"); // Invalid escape + + // - Unicode + assert_unescape_errors(r#""abc \u""#, r"\u"); // Not enough bytes + assert_unescape_errors(r#""abc \u012""#, r"\u012"); // Not enough bytes + assert_unescape_errors(r#""abc \uwu/ abc""#, r"\uwu/ "); // Invalid hex + assert_unescape_errors(r#""abc \uΣ12Σ abc""#, r"\uΣ12Σ"); // Invalid hex, 4 chars more than 4 bytes + assert_unescape_errors(r#""abc \u123Σ abc""#, r"\u123Σ"); // Invalid hex, 4 chars more than 4 bytes + + // Cases that fail at parsing + { + let escaped = r#""abc\""#; + assert!(crate::parse_value(escaped).is_err()); // Parsing fails as incomplete string literal + let string_literal = StringLiteral { value: escaped.to_owned() }; + let error = string_literal.unescape().unwrap_err(); + let TypeQLError::InvalidStringEscape { escape, .. } = &error.errors()[0] else { + panic!("Wrong error type. Was {error:?}") + }; + assert_eq!(escape, r#"\"#); + } + } + + #[ignore] + #[test] + fn time_unescape_ascii() { + let text = generate_string(TIME_UNESCAPE_TEXT_LEN, |x| 32 + (x % 94)); + time_unescape(text); + } + + #[ignore] + #[test] + fn time_unescape_unicode() { + // assert_eq!(None, (0..0x07ff).filter(|x| char::from_u32(*x).is_none()).next()); + let text = generate_string(TIME_UNESCAPE_TEXT_LEN, move |x| x & 0x07ff); + time_unescape(text); + } + + const TIME_UNESCAPE_TEXT_LEN: usize = 100000; + fn time_unescape(text: String) { + use std::time::Instant; + let iters = 10000; + + let string_literal = StringLiteral { value: text }; + let start = Instant::now(); + for _ in 0..iters { + string_literal.unescape().unwrap(); + } + let end = Instant::now(); + println!( + "{iters} on string of length {} iters in {}", + string_literal.value.as_str().len(), + (end - start).as_secs_f64() + ) + } + + fn generate_string(length: usize, mapper: fn(u32) -> u32) -> String { + use rand::{thread_rng, Rng, RngCore}; + let mut rng = thread_rng(); + let capacity: i64 = (1.2 * length as f64).ceil() as i64; + let mut text = String::with_capacity(capacity as usize); + text.push('"'); + let mut sanity: i64 = capacity; + while text.as_str().len() < length + 1 && sanity >= 0 { + sanity -= 1; + match char::from_u32(mapper(rng.next_u32())) { + Some('\\') => { + text.push('\\'); + text.push('\\'); + } + Some('\'') => { + text.push('\\'); + text.push('\''); + } + Some('\"') => { + text.push('\\'); + text.push('\"'); + } + Some('\x08') => { + text.push('\\'); + text.push('b'); + } + Some('\x09') => { + text.push('\\'); + text.push('t'); + } + Some('\x0a') => { + text.push('\\'); + text.push('n'); + } + Some('\x0c') => { + text.push('\\'); + text.push('f'); + } + Some('\x0d') => { + text.push('\\'); + text.push('r'); + } + Some(ch) => text.push(ch), + None => {} + } + } + text.push('"'); + assert!(text.as_str().len() > length && text.as_str().len() < length + 10); + text + } +}