Skip to content

Commit b6a762e

Browse files
Parser: fix exponential parse time on speculative prefix parsing
1 parent 182eae8 commit b6a762e

3 files changed

Lines changed: 74 additions & 2 deletions

File tree

sqlparser_bench/benches/sqlparser_bench.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use criterion::{criterion_group, criterion_main, Criterion};
19-
use sqlparser::dialect::GenericDialect;
19+
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
2020
use sqlparser::keywords::Keyword;
2121
use sqlparser::parser::Parser;
2222
use sqlparser::tokenizer::{Span, Word};
@@ -177,11 +177,34 @@ fn parse_compound_chain(c: &mut Criterion) {
177177
group.finish();
178178
}
179179

180+
/// Benchmark parsing pathological `IF(<keyword-fn>(<keyword-fn>(...x` chains
181+
/// that previously caused 2^N work in `parse_prefix`. Each nested
182+
/// `current_time(` segment used to be explored twice at every level (once via
183+
/// the speculative reserved-word arm, once via the unreserved-word fallback),
184+
/// doubling work per level. Post-fix the cost is linear in chain length.
185+
fn parse_prefix_keyword_call_chain(c: &mut Criterion) {
186+
let mut group = c.benchmark_group("parse_prefix_keyword_call_chain");
187+
let dialect = PostgreSqlDialect {};
188+
189+
for &n in &[10usize, 20, 30] {
190+
let sql = String::from("if(") + &"current_time(".repeat(n) + "x";
191+
192+
group.bench_function(format!("chain_{n}"), |b| {
193+
b.iter(|| {
194+
let _ = Parser::parse_sql(&dialect, std::hint::black_box(&sql));
195+
});
196+
});
197+
}
198+
199+
group.finish();
200+
}
201+
180202
criterion_group!(
181203
benches,
182204
basic_queries,
183205
word_to_ident,
184206
parse_many_identifiers,
185-
parse_compound_chain
207+
parse_compound_chain,
208+
parse_prefix_keyword_call_chain
186209
);
187210
criterion_main!(benches);

src/parser/mod.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#[cfg(not(feature = "std"))]
1616
use alloc::{
1717
boxed::Box,
18+
collections::BTreeMap,
1819
format,
1920
string::{String, ToString},
2021
vec,
@@ -24,6 +25,9 @@ use core::{
2425
fmt::{self, Display},
2526
str::FromStr,
2627
};
28+
#[cfg(feature = "std")]
29+
use std::collections::BTreeMap;
30+
2731
use helpers::attached_token::AttachedToken;
2832

2933
use log::debug;
@@ -359,6 +363,9 @@ pub struct Parser<'a> {
359363
options: ParserOptions,
360364
/// Ensures the stack does not overflow by limiting recursion depth.
361365
recursion_counter: RecursionCounter,
366+
/// Cached errors from failed `parse_prefix` calls, keyed by start
367+
/// position. See [`Parser::parse_prefix`] for the 2^N pattern this guards.
368+
failed_prefix_positions: BTreeMap<usize, ParserError>,
362369
}
363370

364371
impl<'a> Parser<'a> {
@@ -385,6 +392,7 @@ impl<'a> Parser<'a> {
385392
dialect,
386393
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
387394
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
395+
failed_prefix_positions: BTreeMap::new(),
388396
}
389397
}
390398

@@ -446,6 +454,7 @@ impl<'a> Parser<'a> {
446454
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithSpan>) -> Self {
    // Memoized `parse_prefix` failures are keyed by token position, so they
    // say nothing about a different token stream; discard them first.
    self.failed_prefix_positions.clear();
    // Rewind the cursor and install the replacement tokens.
    self.index = 0;
    self.tokens = tokens;
    self
}
451460

@@ -1717,6 +1726,23 @@ impl<'a> Parser<'a> {
17171726
return prefix;
17181727
}
17191728

1729+
// Memoize failed attempts to break 2^N speculation: both the
1730+
// reserved-word and unreserved-word arms can recurse into the
1731+
// same downstream position, so without this short-circuit
1732+
// inputs like `IF(current_time(current_time(...x` re-walk the
1733+
// chain at every level.
1734+
let start_index = self.index;
1735+
if let Some(cached) = self.failed_prefix_positions.get(&start_index) {
1736+
return Err(cached.clone());
1737+
}
1738+
let result = self.parse_prefix_inner();
1739+
if let Err(ref e) = result {
1740+
self.failed_prefix_positions.insert(start_index, e.clone());
1741+
}
1742+
result
1743+
}
1744+
1745+
fn parse_prefix_inner(&mut self) -> Result<Expr, ParserError> {
17201746
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
17211747
// string literal represents a literal of that type. Some examples:
17221748
//

tests/sqlparser_common.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19004,3 +19004,26 @@ fn parse_compound_chain_no_exponential_blowup() {
1900419004
rx.recv_timeout(Duration::from_secs(5))
1900519005
.expect("parser should reject this quickly, not loop exponentially");
1900619006
}
19007+
19008+
/// Regression test for the 2^N parse-time blowup in `parse_prefix` on inputs
/// like `IF(current_time(current_time(...x`. Each nested `current_time(` used
/// to be explored twice at every level (once via the speculative reserved-word
/// arm, once via the unreserved-word fallback), doubling work per level.
/// Post-fix the failing parse short-circuits via the position-keyed cache.
///
/// The parse runs on a worker thread so that a regression shows up as a
/// timeout failure instead of hanging the test suite. The test also checks
/// that the unterminated input is actually rejected, not merely that parsing
/// finishes in time.
#[test]
fn parse_prefix_keyword_call_chain_no_exponential_blowup() {
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    // `if(` + 30 nested `current_time(` + `x`, with no closing parentheses.
    let sql = String::from("if(") + &"current_time(".repeat(30) + "x";

    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        let rejected = Parser::parse_sql(&PostgreSqlDialect {}, &sql).is_err();
        // The receiver may already have timed out and been dropped; a failed
        // send is irrelevant in that case.
        let _ = tx.send(rejected);
    });

    let rejected = rx
        .recv_timeout(Duration::from_secs(5))
        .expect("parser should reject this quickly, not loop exponentially");
    assert!(
        rejected,
        "unterminated keyword call chain should be a parse error"
    );
}

0 commit comments

Comments
 (0)