From b32c48053f13b547c4d651b82660c016165d1102 Mon Sep 17 00:00:00 2001 From: Vincent Roy Date: Wed, 18 Jun 2025 16:53:00 +0200 Subject: [PATCH] Expose AST in validation method --- sds/Cargo.toml | 2 +- sds/src/lib.rs | 4 ++- sds/src/parser/ast.rs | 52 ++++++++++++++++--------------- sds/src/validation.rs | 71 ++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 27 deletions(-) diff --git a/sds/Cargo.toml b/sds/Cargo.toml index 7560fe9c..215a178b 100644 --- a/sds/Cargo.toml +++ b/sds/Cargo.toml @@ -27,7 +27,7 @@ regex-automata = "0.4.7" # Switch over to the original repo when this issue is resolved: https://github.com/rust-lang/regex/issues/1241 regex-automata-fork = { git = "https://github.com/fbryden/regex", rev = "6952250af962ca3e364da47382b16dba9c703431", package = "regex-automata" } regex-syntax = "0.7.5" -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } serde_with = "3.6.1" strum = { version = "0.25", features = ["derive"] } thiserror = "1.0.58" diff --git a/sds/src/lib.rs b/sds/src/lib.rs index 7e45c30e..51afca76 100644 --- a/sds/src/lib.rs +++ b/sds/src/lib.rs @@ -40,6 +40,7 @@ pub use path::{Path, PathSegment}; pub use rule_match::{ReplacementType, RuleMatch}; pub use scanner::shared_pool::{SharedPool, SharedPoolGuard}; +pub use parser::{ast::Ast, regex_parser::parse_regex_pattern}; pub use scanner::error::MatchValidationError; pub use scanner::{ config::RuleConfig, @@ -52,7 +53,8 @@ pub use scanner::{ }; pub use scoped_ruleset::ExclusionCheck; pub use validation::{ - get_regex_complexity_estimate_very_slow, validate_regex, RegexValidationError, + get_regex_complexity_estimate_very_slow, validate_regex, validate_regex_and_get_ast, + RegexValidationError, }; #[cfg(any(feature = "testing", feature = "bench"))] diff --git a/sds/src/parser/ast.rs b/sds/src/parser/ast.rs index 6ce33902..007b7ef0 100644 --- a/sds/src/parser/ast.rs +++ b/sds/src/parser/ast.rs @@ -1,23 +1,25 @@ 
+use serde::{Deserialize, Serialize}; use std::rc::Rc; /// The Abstract Syntax Tree describing a regex pattern. The AST is designed /// to preserve behavior, but doesn't necessarily preserve the exact syntax. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", content = "content")] pub enum Ast { Empty, Literal(Literal), Concat(Vec), Group(Rc), CharacterClass(CharacterClass), - // May be empty Alternation(Vec), Repetition(Repetition), Assertion(AssertionType), Flags(Flags), } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct Literal { + #[serde(rename = "value")] pub c: char, // whether a literal is escaped or not can change the behavior in some cases, @@ -25,31 +27,32 @@ pub struct Literal { pub escaped: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "group_type", content = "content")] pub enum Group { Capturing(CaptureGroup), NonCapturing(NonCapturingGroup), NamedCapturing(NamedCapturingGroup), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct CaptureGroup { pub inner: Ast, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct NonCapturingGroup { pub flags: Flags, pub inner: Ast, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct NamedCapturingGroup { pub name: String, pub inner: Ast, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum CharacterClass { Bracket(BracketCharacterClass), Perl(PerlCharacterClass), @@ -61,13 +64,13 @@ pub enum CharacterClass { UnicodeProperty(UnicodePropertyClass), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct UnicodePropertyClass { pub negate: bool, pub name: String, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum QuantifierKind { /// * ZeroOrMore, @@ -83,13 +86,13 @@ pub enum 
QuantifierKind { OneOrMore, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Quantifier { pub lazy: bool, pub kind: QuantifierKind, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum PerlCharacterClass { Digit, Space, @@ -99,13 +102,13 @@ pub enum PerlCharacterClass { NonWord, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct BracketCharacterClass { pub negated: bool, pub items: Vec, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum BracketCharacterClassItem { Literal(char), Range(char, char), @@ -118,13 +121,14 @@ pub enum BracketCharacterClassItem { NotVerticalWhitespace, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct AsciiClass { pub negated: bool, pub kind: AsciiClassKind, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] pub enum AsciiClassKind { Alnum, Alpha, @@ -142,13 +146,16 @@ pub enum AsciiClassKind { Xdigit, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Repetition { + #[serde(rename = "quantifier")] pub quantifier: Quantifier, + #[serde(rename = "expression")] pub inner: Rc, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] pub enum AssertionType { /// \b WordBoundary, @@ -167,20 +174,17 @@ pub enum AssertionType { /// \z EndText, - - /// \Z EndTextOptionalNewline, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Flags { - /// Flags before a "-" pub add: Vec, - /// Flags after a "-" pub remove: Vec, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] pub enum Flag { /// i CaseInsensitive, diff --git a/sds/src/validation.rs 
b/sds/src/validation.rs index 89921b65..fababbc4 100644 --- a/sds/src/validation.rs +++ b/sds/src/validation.rs @@ -1,5 +1,6 @@ use crate::normalization::rust_regex_adapter::{convert_to_rust_regex, QUANTIFIER_LIMIT}; use crate::parser::error::ParseError; +use crate::parser::regex_parser::parse_regex_pattern; use regex_automata::meta::{self}; use thiserror::Error; @@ -40,6 +41,16 @@ pub fn validate_regex(input: &str) -> Result<(), RegexValidationError> { validate_and_create_regex(input).map(|_| ()) } +/// Checks that a regex pattern is valid for use in an SDS scanner and returns the AST if valid. +pub fn validate_regex_and_get_ast( + input: &str, +) -> Result { + // This is the same as `validate_and_create_regex`, but removes the actual Regex type + // to create a more stable API for external users of the crate. + let sds_ast = parse_regex_pattern(input)?; + Ok(sds_ast) +} + pub fn get_regex_complexity_estimate_very_slow(input: &str) -> Result { // The regex crate doesn't directly give you access to the "complexity", but it does // reject if it's too large, so we can binary search to find the limit. 
@@ -115,7 +126,7 @@ fn build_regex( mod test { use crate::validation::{ get_regex_complexity_estimate_very_slow, validate_and_create_regex, validate_regex, - RegexValidationError, + validate_regex_and_get_ast, RegexValidationError, }; #[test] @@ -183,4 +194,62 @@ mod test { Ok(1_040_136) ); } + + #[test] + fn test_parse_regex_pattern() { + let pattern: &'static str = "^(?:\\w|b)?"; + let ast = validate_regex_and_get_ast(pattern).unwrap(); + let json = serde_json::to_string_pretty(&ast).unwrap(); + assert_eq!( + r###"{ + "type": "Concat", + "content": [ + { + "type": "Assertion", + "content": "startline" + }, + { + "type": "Repetition", + "content": { + "quantifier": { + "lazy": false, + "kind": "ZeroOrOne" + }, + "expression": { + "type": "Group", + "content": { + "group_type": "NonCapturing", + "content": { + "flags": { + "add": [], + "remove": [] + }, + "inner": { + "type": "Alternation", + "content": [ + { + "type": "CharacterClass", + "content": { + "Perl": "Word" + } + }, + { + "type": "Literal", + "content": { + "value": "b", + "escaped": false + } + } + ] + } + } + } + } + } + } + ] +}"###, + json + ); + } }