Skip to content
1 change: 1 addition & 0 deletions rust/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ rust_test(
deps = [
"@crates//:syn",
"@crates//:proc-macro2",
"@crates//:rand",
],
)

Expand Down
261 changes: 228 additions & 33 deletions rust/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,62 +343,73 @@ impl fmt::Display for StructLiteral {

impl StringLiteral {
pub fn unescape(&self) -> Result<String> {
self.process_unescape(|bytes, _buf, rest| match bytes[1] {
BSP => Ok(('\x08', 2)),
TAB => Ok(('\x09', 2)),
LF_ => Ok(('\x0a', 2)),
FF_ => Ok(('\x0c', 2)),
CR_ => Ok(('\x0d', 2)),
c @ (b'"' | b'\'' | b'\\') => Ok((c as char, 2)),
b'u' => todo!("Unicode escape handling"),
_ => Err(TypeQLError::InvalidStringEscape {
full_string: rest.to_owned(),
escape: format!(r"\{}", rest.chars().nth(1).unwrap()),
self.process_unescape(|bytes| {
if bytes.len() < 2 {
return Err(1);
}
match bytes[1] {
BSP => Ok(('\x08', 2)),
TAB => Ok(('\x09', 2)),
LF_ => Ok(('\x0a', 2)),
FF_ => Ok(('\x0c', 2)),
CR_ => Ok(('\x0d', 2)),
c @ (b'"' | b'\'' | b'\\') => Ok((c as char, 2)),
b'u' => {
let escape = &bytes[2..std::cmp::min(6, bytes.len())];
match decode_four_hex_bytes(escape) {
Some(char) => Ok((char, 6)),
None => Err(6),
}
}
_ => Err(2),
}
.into()),
})
}

pub fn unescape_regex(&self) -> Result<String> {
self.process_unescape(|bytes, _, _| match bytes[1] {
c @ b'"' => Ok((c as char, 2)),
self.process_unescape(|bytes| match bytes.get(1) {
Some(b'"') => Ok(('"', 2)),
_ => Ok(('\\', 1)),
})
}

fn process_unescape<F>(&self, escape_handler: F) -> Result<String>
where
F: Fn(&[u8], &mut String, &str) -> Result<(char, usize)>,
F: Fn(&[u8]) -> std::result::Result<(char, usize), usize>,
{
let bytes = self.value.as_bytes();
assert_eq!(bytes[0], bytes[bytes.len() - 1]);
assert!(matches!(bytes[0], b'\'' | b'"'));

let escaped_string = &self.value[1..self.value.len() - 1];
let mut buf = String::with_capacity(escaped_string.len());
let mut rest = escaped_string;

let mut buf = Vec::with_capacity(escaped_string.len());
let mut rest: &[u8] = escaped_string.as_bytes();
while !rest.is_empty() {
let (char, escaped_len) = if rest.as_bytes()[0] == b'\\' {
let bytes = rest.as_bytes();

if bytes.len() < 2 {
return Err(TypeQLError::InvalidStringEscape {
full_string: escaped_string.to_owned(),
escape: String::from(r"\"),
let escaped_len = if rest[0] == b'\\' {
match escape_handler(rest) {
Ok((char, escaped_len)) => {
let start = buf.len();
buf.resize(buf.len() + char.len_utf8(), 0);
char.encode_utf8(&mut buf[start..]);
rest = &rest[escaped_len..];
}
Err(considered_escape_seq_length) => {
let offset = escaped_string.len() - rest.len();
let considered_escape_sequence =
escaped_string[offset..].chars().take(considered_escape_seq_length).collect();
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've not changed this, but I also don't get why it's semantically wrong.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, `ceil_char_boundary` doesn't appear to be stable in the version of Rust we're using.

return Err(TypeQLError::InvalidStringEscape {
full_string: escaped_string.to_owned(),
escape: considered_escape_sequence,
}
.into());
}
.into());
}

escape_handler(bytes, &mut buf, escaped_string)?
} else {
let char = rest.chars().next().expect("string is non-empty");
(char, char.len_utf8())
buf.push(rest[0]);
rest = &rest[1..];
};
buf.push(char);
rest = &rest[escaped_len..];
}
Ok(buf)
Ok(String::from_utf8(buf).expect("Expected valid utf8").to_owned())
}
}

Expand All @@ -407,3 +418,187 @@ const TAB: u8 = b't';
const LF_: u8 = b'n';
const FF_: u8 = b'f';
const CR_: u8 = b'r';

/// Decodes exactly four ASCII hexadecimal digits into the `char` they encode.
///
/// Digits are read most-significant first and may be upper or lower case.
///
/// Returns `None` when `bytes` is not exactly four bytes long, when any byte is not an ASCII
/// hex digit, or when the decoded value is not a valid `char` (the surrogate range
/// `D800..=DFFF`).
fn decode_four_hex_bytes(bytes: &[u8]) -> Option<char> {
    if bytes.len() != 4 {
        return None;
    }
    // Four nibbles occupy at most 16 bits, so the accumulator can never overflow a `u32`.
    let code_point =
        bytes.iter().try_fold(0u32, |acc, &byte| Some((acc << 4) | (byte as char).to_digit(16)?))?;
    // Surrogate values are reachable from user-written escapes such as `\ud800`, so they are
    // reported as `None` rather than asserted against.
    char::from_u32(code_point)
}

#[cfg(test)]
pub mod tests {
use crate::{
value::{StringLiteral, TypeQLError},
Result,
};

/// Parses `escaped` as a value literal and returns the string literal it contains.
///
/// Panics if parsing fails or if the parsed value is not a string.
fn parse_to_string_literal(escaped: &str) -> StringLiteral {
    match crate::parse_value(escaped).unwrap() {
        crate::ValueLiteral::String(literal) => literal,
        _ => panic!("Not parsed as string"),
    }
}

#[test]
fn test_unescape_regex() {
    // Literals the parser accepts, paired with the text `unescape_regex` must produce:
    // an escaped double quote is unwrapped, any other backslash is kept verbatim.
    let parseable_cases = [(r#""a\"b\"c""#, r#"a"b"c"#), (r#""abc\123""#, r#"abc\123"#)];
    for (escaped, expected) in parseable_cases {
        let unescaped = parse_to_string_literal(escaped).unescape_regex().unwrap();
        assert_eq!(unescaped.as_str(), expected);
    }

    // Cases that fail at parsing
    let unterminated = r#""abc\""#;
    // Parsing fails as incomplete string literal, so the literal is built by hand instead.
    assert!(crate::parse_value(unterminated).is_err());
    let literal = StringLiteral { value: unterminated.to_owned() };
    let unescaped = literal.unescape_regex().unwrap();
    assert_eq!(unescaped.as_str(), r#"abc\"#);
}

/// Asserts that the literal written as `escaped` parses and unescapes to exactly `expected`.
fn assert_unescapes_to(escaped: &str, expected: &str) {
    let literal = parse_to_string_literal(escaped);
    assert_eq!(literal.unescape().unwrap(), expected);
}

/// Asserts that unescaping the literal written as `escaped` fails, and that the first reported
/// error is an `InvalidStringEscape` whose offending sequence equals `expected_escape_sequence`.
fn assert_unescape_errors(escaped: &str, expected_escape_sequence: &str) {
    let error = parse_to_string_literal(escaped).unescape().unwrap_err();
    match &error.errors()[0] {
        TypeQLError::InvalidStringEscape { escape, .. } => assert_eq!(escape, expected_escape_sequence),
        _ => panic!("Wrong error type. Was {error:?}"),
    }
}

#[test]
fn test_unescape() {
    // Literals that must unescape cleanly, paired with the expected decoded text.
    let accepted = [
        (r#""a\tb\tc""#, "a\tb\tc"),
        (r#""a\"b\"c""#, r#"a"b"c"#),
        (r#""a\'b\'c""#, r#"a'b'c"#),
        (r#""a\\b\\c""#, r#"a\b\c"#),
        // Unicode escapes: lower-case hex, upper-case hex, and only four digits consumed.
        (r#""abc \u0ca0\u005f\u0ca0""#, "abc ಠ_ಠ"),
        (r#""abc \u0CA0\u005F\u0CA0""#, "abc ಠ_ಠ"),
        (r#""abc \u0CA01234""#, "abc ಠ1234"),
    ];
    for (escaped, expected) in accepted {
        assert_unescapes_to(escaped, expected);
    }

    // Literals that must be rejected, paired with the escape sequence the error should report.
    let rejected = [
        // Invalid escape
        (r#""ab\c""#, r"\c"),
        // Unicode: not enough bytes
        (r#""abc \u""#, r"\u"),
        (r#""abc \u012""#, r"\u012"),
        // Unicode: invalid hex
        (r#""abc \uwu/ abc""#, r"\uwu/ "),
        // Unicode: invalid hex, 4 chars more than 4 bytes
        (r#""abc \uΣ12Σ abc""#, r"\uΣ12Σ"),
        (r#""abc \u123Σ abc""#, r"\u123Σ"),
    ];
    for (escaped, expected_escape_sequence) in rejected {
        assert_unescape_errors(escaped, expected_escape_sequence);
    }

    // Cases that fail at parsing
    let unterminated = r#""abc\""#;
    // Parsing fails as incomplete string literal, so the literal is built by hand instead.
    assert!(crate::parse_value(unterminated).is_err());
    let literal = StringLiteral { value: unterminated.to_owned() };
    let error = literal.unescape().unwrap_err();
    match &error.errors()[0] {
        TypeQLError::InvalidStringEscape { escape, .. } => assert_eq!(escape, r#"\"#),
        _ => panic!("Wrong error type. Was {error:?}"),
    }
}

/// Manual timing run over text drawn from the ASCII range `32..=125`.
#[ignore]
#[test]
fn time_unescape_ascii() {
    time_unescape(generate_string(TIME_UNESCAPE_TEXT_LEN, |raw| 32 + (raw % 94)));
}

/// Manual timing run over text drawn from the code points `0..=0x07ff`.
#[ignore]
#[test]
fn time_unescape_unicode() {
    // Every value in `0..=0x07ff` lies below the surrogate range, so `char::from_u32`
    // accepts all of them and the generator never has to discard a sample.
    let sample = generate_string(TIME_UNESCAPE_TEXT_LEN, |raw| raw & 0x07ff);
    time_unescape(sample);
}

/// Target byte length of the generated text used by the `time_unescape_*` runs.
const TIME_UNESCAPE_TEXT_LEN: usize = 100000;

/// Unescapes `text` repeatedly and prints the total wall-clock time taken.
///
/// `text` must be a complete quoted string literal; the run panics if unescaping fails.
fn time_unescape(text: String) {
    use std::time::Instant;
    let iters = 10000;

    let string_literal = StringLiteral { value: text };
    let start = Instant::now();
    for _ in 0..iters {
        string_literal.unescape().unwrap();
    }
    let elapsed = start.elapsed();
    println!(
        "{iters} iters on string of length {} in {}s",
        string_literal.value.len(),
        elapsed.as_secs_f64()
    )
}

/// Builds a random double-quoted string literal whose byte length is just over `length`.
///
/// Each random `u32` is passed through `mapper` and interpreted as a code point; values that are
/// not valid `char`s are skipped. Backslash, both quote characters, and the control characters
/// that have a short escape (`\b`, `\t`, `\n`, `\f`, `\r`) are written in escaped form; every
/// other character is written as-is.
///
/// Panics if the finished literal is not longer than `length` and shorter than `length + 10` bytes.
fn generate_string(length: usize, mapper: fn(u32) -> u32) -> String {
    use rand::{thread_rng, RngCore};

    let mut rng = thread_rng();
    let capacity = (1.2 * length as f64).ceil() as i64;
    let mut text = String::with_capacity(capacity as usize);
    text.push('"');

    // Iteration budget so that a mapper which rarely yields a valid char cannot loop forever.
    let mut remaining_attempts = capacity;
    while text.len() < length + 1 && remaining_attempts >= 0 {
        remaining_attempts -= 1;
        let Some(ch) = char::from_u32(mapper(rng.next_u32())) else {
            continue;
        };
        // The character that follows the backslash, for characters that need escaping.
        let escape_letter = match ch {
            '\\' | '\'' | '"' => Some(ch),
            '\x08' => Some('b'),
            '\x09' => Some('t'),
            '\x0a' => Some('n'),
            '\x0c' => Some('f'),
            '\x0d' => Some('r'),
            _ => None,
        };
        match escape_letter {
            Some(letter) => {
                text.push('\\');
                text.push(letter);
            }
            None => text.push(ch),
        }
    }

    text.push('"');
    assert!(text.len() > length && text.len() < length + 10);
    text
}
}