diff --git a/README.md b/README.md index 294f3c9..d0f5f96 100644 --- a/README.md +++ b/README.md @@ -45,14 +45,16 @@ Benchmark on CPython 3.13, walking the AST of `difflib.py` (~2000 lines, | implementation | min time | relative | | -------------------------- | -------- | -------- | -| `ast.walk` (stdlib) | ~2.3 ms | 1× | -| pure-Python equivalent | ~1.0 ms | ~2× | -| `fast_walk.walk_dfs` | ~18 µs | ~130× | -| `fast_walk.walk_unordered` | ~13 µs | ~180× | +| `ast.walk` (stdlib) | ~1.9 ms | 1× | +| pure-Python equivalent | ~400 µs | ~5× | +| `fast_walk.walk_dfs` | ~5.6 µs | ~340× | +| `fast_walk.walk_unordered` | ~4.3 µs | ~440× | Both `fast_walk` entry points are semantically equivalent to `list(ast.walk(node))` — they return the same set of AST nodes. They -differ only in visit order. +differ only in visit order. User-attached attributes outside `_fields` +(e.g. a `.parent` back-reference set by an AST transformer) are +ignored, matching `ast.walk`'s behaviour. ## Development diff --git a/src/lib.rs b/src/lib.rs index 9e3bd3d..9837ae8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,17 +12,115 @@ mod pydict; -use std::cell::Cell; +use std::cell::{Cell, RefCell}; use std::sync::atomic::{AtomicBool, Ordering}; use pyo3::exceptions::PyDeprecationWarning; -use pyo3::ffi::{self, PyDictObject, PyListObject, PyObject, PyTypeObject}; +use pyo3::ffi::{self, PyListObject, PyObject, PyTypeObject}; use pyo3::types::{PyList, PyModule, PyType}; use pyo3::{PyTypeInfo, prelude::*}; -/// Reverse iterator over the values of a Python dict whose keys are all -/// strings — the layout used by instance `__dict__`s. Reads the -/// `PyDictKeysObject` entry table directly and skips null (deleted) slots. +/// Open-addressed, direct-mapped lookup from `*mut PyTypeObject` to an +/// AST-classification code. Specialized for the ~130 `ast.AST` +/// subclasses; populated once per thread at first walk and then +/// read-only on the hot path. 
+///
+/// **Value encoding** (chosen to collapse the two hot-path predicates
+/// "is this an AST subclass?" and "how many `_fields` does it have?"
+/// into one L1 load):
+///
+/// - `0` — not an AST type (or absent from the table). Used for
+///   primitive values like `str`, `int`, `None`, `bytes`, and also for
+///   any type that was never inserted, so a missing key reads exactly
+///   like an explicit "not AST" entry.
+/// - `n >= 1` — AST type with `n - 1` `_fields`. A single
+///   `table.lookup(t)` therefore answers both "is `t` a known AST
+///   type?" (`> 0`) and "how many `_fields` does it have?" (`- 1`).
+///
+/// Stdlib AST types top out at 7 `_fields`, so `u8` is more than
+/// enough; the table stays at 256 bytes of values + 2 KB of keys (plus
+/// one occupancy counter) = one L1-resident data structure for the
+/// whole walk.
+///
+/// Layout: parallel `keys` and `values` arrays of `SIZE` slots each.
+/// Keys are u64 (pointer as integer); empty slot is `key == 0`. Values
+/// are u8. Index function: `(ptr >> 4) & (SIZE - 1)` — type objects
+/// are allocator-aligned, so shifting right by 4 drops the alignment
+/// zeros and gives well-distributed indices across our ~130 live
+/// types.
+///
+/// Capacity: at most `FIELD_TABLE_SIZE - 1` distinct types are stored.
+/// One slot is always left empty so that every linear probe — in
+/// `insert_ast` and in `lookup` — is guaranteed to hit an empty slot and
+/// terminate. Types offered after that point are dropped and therefore
+/// read back as `0` ("not AST").
+const FIELD_TABLE_SIZE: usize = 256; // power of two, load factor ~0.5
+const FIELD_TABLE_MASK: usize = FIELD_TABLE_SIZE - 1;
+
+struct FieldTable {
+    /// Type pointers as integers; `0` marks an empty slot.
+    keys: [u64; FIELD_TABLE_SIZE],
+    /// Encoded classification for the key in the same slot.
+    values: [u8; FIELD_TABLE_SIZE],
+    /// Number of occupied slots. Kept strictly below `FIELD_TABLE_SIZE`
+    /// so that probing always terminates.
+    len: usize,
+}
+
+impl FieldTable {
+    /// Create an empty table: every key is `0`, every value is `0`.
+    fn new() -> Self {
+        Self {
+            keys: [0; FIELD_TABLE_SIZE],
+            values: [0; FIELD_TABLE_SIZE],
+            len: 0,
+        }
+    }
+
+    /// Store an AST type with its `_fields` length. The value written
+    /// is `n_fields + 1` (saturating), so a stored type never reads back
+    /// as `0`.
+    ///
+    /// Re-inserting a type that is already present overwrites its value.
+    /// A null `ptr` is ignored (it would collide with the empty-slot
+    /// sentinel), and a new type is dropped when storing it would use up
+    /// the last empty slot.
+    fn insert_ast(&mut self, ptr: *mut PyTypeObject, n_fields: u8) {
+        let key = ptr as u64;
+        debug_assert!(key != 0, "null type pointer");
+        debug_assert!(n_fields < u8::MAX, "n_fields would overflow encoding");
+        if key == 0 {
+            // Release-build guard: key 0 means "empty slot"; writing a
+            // value for it would corrupt every later miss.
+            return;
+        }
+        // Saturating: in release builds a plain `+ 1` on 255 would wrap
+        // to 0 and silently reclassify the type as "not AST".
+        let encoded = n_fields.saturating_add(1);
+        let mut idx = ((key >> 4) as usize) & FIELD_TABLE_MASK;
+        // Linear probe until we find the key or an empty slot. At least
+        // one slot is always empty, so this loop terminates.
+        loop {
+            let k = self.keys[idx];
+            if k == key {
+                self.values[idx] = encoded;
+                return;
+            }
+            if k == 0 {
+                if self.len + 1 >= FIELD_TABLE_SIZE {
+                    // Filling the last empty slot would make probes for
+                    // absent keys spin forever; drop the entry instead.
+                    return;
+                }
+                self.keys[idx] = key;
+                self.values[idx] = encoded;
+                self.len += 1;
+                return;
+            }
+            idx = (idx + 1) & FIELD_TABLE_MASK;
+        }
+    }
+
+    /// Raw encoded value: `0` if not an AST type, `n_fields + 1`
+    /// otherwise. Hot path uses this directly so "is AST?" (`> 0`) and
+    /// "how many fields?" (`- 1`) share a single load.
+    ///
+    /// Terminates because `insert_ast` never fills the last empty slot.
+    #[inline(always)]
+    fn lookup(&self, ptr: *mut PyTypeObject) -> u8 {
+        let key = ptr as u64;
+        let mut idx = ((key >> 4) as usize) & FIELD_TABLE_MASK;
+        loop {
+            // SAFETY: `idx` is always masked to `FIELD_TABLE_MASK`, so it
+            // stays within the fixed-size array.
+            let k = unsafe { *self.keys.get_unchecked(idx) };
+            if k == key {
+                // SAFETY: same masked `idx` as above; `values` has the
+                // same length as `keys`.
+                return unsafe { *self.values.get_unchecked(idx) };
+            }
+            if k == 0 {
+                return 0;
+            }
+            idx = (idx + 1) & FIELD_TABLE_MASK;
+        }
+    }
+}
+
+/// Reverse iterator over the first `limit` values of a Python dict whose
+/// keys are all strings — the layout used by instance `__dict__`s. Reads
+/// the `PyDictKeysObject` entry table directly and skips null (deleted)
+/// slots.
+///
+/// The `limit` parameter is how we walk only the `_fields` portion of an
+/// AST node's dict: for parsed ASTs, CPython stores keys in the order
+/// `_fields ++ _attributes ++ user_added`, so passing `limit = len(_fields)`
+/// guarantees we only see syntactic children — no `lineno`/`col_offset`
+/// ints and no user-attached `.parent` cycles.
pub struct ReverseDictValuesIter { entries: *const pydict::PyDictUnicodeEntry, current: usize, @@ -32,15 +130,17 @@ impl ReverseDictValuesIter { /// # Safety /// /// - `obj` must be a valid pointer to a `PyDictObject` whose keys are - /// all unicode strings (i.e. a split/combined-unicode dict). + /// all unicode strings (i.e. a combined-unicode dict). /// - The dictionary must outlive the iterator and must not be mutated /// while iterating. - pub unsafe fn new(obj: *mut PyDictObject) -> Self { + /// - `limit` must not exceed `dk_nentries`; caller is responsible for + /// clamping when working from an external count (e.g. `_fields` len). + pub unsafe fn new(obj: *mut ffi::PyDictObject, limit: usize) -> Self { unsafe { let dict = &*obj; let keys = &*dict.ma_keys.cast::(); let entries = keys.unicode_entries(); - let n = keys.dk_nentries as usize; + let n = (keys.dk_nentries as usize).min(limit); Self { entries, current: n, @@ -78,35 +178,6 @@ fn get_instance_dict_fast(obj: *mut PyObject) -> Option<*mut PyObject> { } } -/// Check whether `subtype` is a subclass of `ast.AST` within the first two -/// levels of the MRO. Every stdlib AST node is `Concrete -> ast.expr/stmt -/// -> ast.AST` or `Concrete -> ast.AST`, so two hops suffice. -/// -/// Performance notes baked in here: -/// - Early-exit on `first_supertype == PyBaseObject_Type`: primitives like -/// `str`, `NoneType`, `float`, `bytes` inherit directly from `object`, -/// AST subclasses never do. This skips the scattered second `tp_base` -/// load on ~13% of items in typical ASTs. 
-fn issubclass_of_ast( - subtype: *mut PyTypeObject, - base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject), -) -> bool { - let first_supertype = unsafe { (*subtype).tp_base }; - if first_supertype.is_null() { - return false; - } - let py_object_type = &raw mut ffi::PyBaseObject_Type; - if first_supertype == py_object_type { - return false; - } - let (base_ast_type, base_expr_type) = base_ast_and_expr_type; - if first_supertype == base_ast_type { - return true; - } - let second_supertype = unsafe { (*first_supertype).tp_base }; - second_supertype == base_ast_type || second_supertype == base_expr_type -} - /// L1 prefetch hint. No-op on non-x86_64 targets — the Python extension /// builds and runs identically without it, just without the cache-miss /// hiding that benefits `walk_unordered`. @@ -147,33 +218,48 @@ unsafe fn ma_keys_of(node: *mut PyObject) -> Option<*const u8> { } } -/// Per-node body shared by both traversals: enumerate the node's instance -/// dict, pushing AST children onto `stack` and descending into list -/// attributes (`body`, `args`, `decorator_list`, ...) to push any AST -/// items found there. +/// Per-node body shared by both traversals: enumerate the `_fields` +/// slots of the node's instance dict and push AST children onto +/// `stack`. For parsed ASTs, CPython stores dict keys in the order +/// `_fields ++ _attributes ++ user_added`, so limiting the scan to the +/// first `len(_fields)` entries skips the `_attributes` ints (lineno/ +/// col_offset/...) and any user-attached metadata (including cycle- +/// inducing `.parent` back-references) in a single loop bound. /// -/// The `int` fast-path skips ~57% of items in a real Python AST (lineno / -/// col_offset / end_lineno / end_col_offset appear on every node). -/// Expanding the fast-path to str/None/float/bytes was measured and -/// reverted — the extra compares cost more on the AST-hit items than -/// they save on the primitive items. +/// Per-value "is this AST?" 
checks stay on `issubclass_of_ast` rather +/// than the table: the `tp_base == PyBaseObject_Type` early-exit +/// already catches every primitive (str/int/None/...) in one load, and +/// a table probe costs the same on average — substituting one for the +/// other was measured to regress. #[inline(always)] unsafe fn process_node( current_node: *mut PyObject, base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject), py_list_type: *mut PyTypeObject, - py_long_type: *mut PyTypeObject, + field_table: &FieldTable, stack: &mut Vec<*mut PyObject>, ) { + let type_ptr = unsafe { ffi::Py_TYPE(current_node) }; + let encoded = field_table.lookup(type_ptr); + // 0 == not an AST type we know about. Shouldn't normally happen + // (only AST nodes reach the stack) but guards any caller that + // seeds the walk with a non-AST root. + if encoded == 0 { + return; + } + let n_fields = (encoded - 1) as usize; + if n_fields == 0 { + return; + } + let Some(dict) = get_instance_dict_fast(current_node) else { return; }; - for item_ptr in unsafe { ReverseDictValuesIter::new(dict.cast::()) } { + for item_ptr in + unsafe { ReverseDictValuesIter::new(dict.cast::(), n_fields) } + { let item_type = unsafe { ffi::Py_TYPE(item_ptr) }; - if item_type == py_long_type { - continue; - } if item_type == py_list_type { let list = item_ptr as *mut PyListObject; let length = unsafe { (*(list as *mut ffi::PyVarObject)).ob_size }; @@ -191,15 +277,44 @@ unsafe fn process_node( } } +/// Check whether `subtype` is a subclass of `ast.AST` within the first two +/// levels of the MRO. Every stdlib AST node is `Concrete -> ast.expr/stmt +/// -> ast.AST` or `Concrete -> ast.AST`, so two hops suffice. +/// +/// Performance notes baked in here: +/// - Early-exit on `first_supertype == PyBaseObject_Type`: primitives like +/// `str`, `NoneType`, `float`, `bytes` inherit directly from `object`, +/// AST subclasses never do. 
This skips the scattered second `tp_base` +/// load on ~40% of items (non-AST values in `_fields` slots). +fn issubclass_of_ast( + subtype: *mut PyTypeObject, + base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject), +) -> bool { + let first_supertype = unsafe { (*subtype).tp_base }; + if first_supertype.is_null() { + return false; + } + let py_object_type = &raw mut ffi::PyBaseObject_Type; + if first_supertype == py_object_type { + return false; + } + let (base_ast_type, base_expr_type) = base_ast_and_expr_type; + if first_supertype == base_ast_type { + return true; + } + let second_supertype = unsafe { (*first_supertype).tp_base }; + second_supertype == base_ast_type || second_supertype == base_expr_type +} + /// Strict depth-first pre-order traversal. fn walk_node_dfs( node: *mut PyObject, base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject), py_list_type: *mut PyTypeObject, + field_table: &FieldTable, result_list: &mut Vec<*mut PyObject>, ) -> PyResult<()> { let mut stack = vec![node]; - let py_long_type = &raw mut ffi::PyLong_Type; while let Some(current_node) = stack.pop() { result_list.push(current_node); @@ -208,7 +323,7 @@ fn walk_node_dfs( current_node, base_ast_and_expr_type, py_list_type, - py_long_type, + field_table, &mut stack, ); } @@ -229,11 +344,11 @@ fn walk_node_unordered( node: *mut PyObject, base_ast_and_expr_type: (*mut PyTypeObject, *mut PyTypeObject), py_list_type: *mut PyTypeObject, + field_table: &FieldTable, result_list: &mut Vec<*mut PyObject>, ) -> PyResult<()> { const BATCH: usize = 4; let mut stack = vec![node]; - let py_long_type = &raw mut ffi::PyLong_Type; let mut batch: [*mut PyObject; BATCH] = [std::ptr::null_mut(); BATCH]; while !stack.is_empty() { @@ -255,7 +370,7 @@ fn walk_node_unordered( current, base_ast_and_expr_type, py_list_type, - py_long_type, + field_table, &mut stack, ); } @@ -268,6 +383,8 @@ fn walk_node_unordered( thread_local! 
{ static BASE_AST_TYPE_AND_EXPR: Cell> = const { Cell::new(None) }; + static AST_FIELD_TABLE: RefCell>> = + const { RefCell::new(None) }; } /// Resolve `ast.AST` and `ast.expr` to their raw type pointers. Kept out @@ -292,6 +409,55 @@ fn resolve_base_types(py: Python) -> PyResult<(*mut PyTypeObject, *mut PyTypeObj }) } +/// Walk every subclass of `ast.AST` at first-use and record each type's +/// `len(_fields)`. The resulting direct-mapped table answers the hot-loop +/// lookup in one L1 load per node — no Python calls, no `_attributes` +/// scanning, no hashing. +#[inline(never)] +fn prebuild_field_table(py: Python<'_>) -> PyResult> { + let ast_module = py.import("ast")?; + let ast_class = ast_module.getattr("AST")?.cast_into::()?; + + let mut table = Box::new(FieldTable::new()); + let mut stack: Vec> = vec![ast_class]; + while let Some(t) = stack.pop() { + let n_fields = t + .getattr("_fields") + .ok() + .and_then(|f| f.len().ok()) + .unwrap_or(0); + // `_fields` tuples in the stdlib top out at 7 entries. Saturate + // for safety so a rogue subclass with a huge `_fields` tuple + // can't break the u8 encoding. + // Saturate to u8::MAX - 1: the table reserves n_fields+1 as the + // encoded value so we need headroom below 255. + let n = n_fields.min((u8::MAX - 1) as usize) as u8; + table.insert_ast(t.as_type_ptr(), n); + let subs = t.call_method0("__subclasses__")?; + for sub in subs.try_iter()? { + stack.push(sub?.cast_into::()?); + } + } + Ok(table) +} + +/// Run `body` with a `&FieldTable` pinning the prebuilt `_fields`-length +/// cache. The table is built on first use per thread and reused for all +/// subsequent walks on that thread. 
+#[inline(always)]
+fn with_field_table<R>(
+    py: Python<'_>,
+    body: impl FnOnce(&FieldTable) -> PyResult<R>,
+) -> PyResult<R> {
+    AST_FIELD_TABLE.with(|cache| {
+        if cache.borrow().is_none() {
+            // Build with no `RefCell` borrow held: `prebuild_field_table`
+            // calls back into Python (`getattr`, `__subclasses__`), and
+            // anything that re-enters a walk on this thread from there
+            // must not find the cell exclusively borrowed.
+            let table = prebuild_field_table(py)?;
+            let mut slot = cache.borrow_mut();
+            // A re-entrant call may have populated the cache meanwhile;
+            // keep the first table so live `&FieldTable`s stay valid.
+            if slot.is_none() {
+                *slot = Some(table);
+            }
+        }
+        // Shared borrow only: nested walks started while `body` runs can
+        // take their own shared borrow instead of panicking.
+        let guard = cache.borrow();
+        let table = guard
+            .as_deref()
+            .expect("AST field table is initialised just above");
+        body(table)
+    })
+}
+
 /// Construct a Python list from a Vec of owned-reference pointers, going
 /// directly through the FFI `PyList_New` + `PyList_SET_ITEM` path. Avoids
 /// the per-item `Bound` allocation in `PyList::new(iter)`.
@@ -319,15 +485,14 @@ fn vec_into_pylist<'py>(py: Python<'py>, items: &[*mut PyObject]) -> PyResult(py: Python<'py>, node: Bound<'py, PyAny>) -> PyResult> {
-    let mut result_list = Vec::new();
     let base = resolve_base_types(py)?;
-    walk_node_dfs(
-        node.as_ptr(),
-        base,
-        PyList::type_object_raw(py),
-        &mut result_list,
-    )?;
-    vec_into_pylist(py, &result_list)
+    let py_list_type = PyList::type_object_raw(py);
+    let node_ptr = node.as_ptr();
+    with_field_table(py, |table| {
+        let mut result_list = Vec::new();
+        walk_node_dfs(node_ptr, base, py_list_type, table, &mut result_list)?;
+        vec_into_pylist(py, &result_list)
+    })
 }
 
 /// Walk the AST rooted at `node` and return every descendant (including
@@ -339,15 +504,14 @@ fn walk_dfs<'py>(py: Python<'py>, node: Bound<'py, PyAny>) -> PyResult(py: Python<'py>, node: Bound<'py, PyAny>) -> PyResult> {
-    let mut result_list = Vec::new();
     let base = resolve_base_types(py)?;
-    walk_node_unordered(
-        node.as_ptr(),
-        base,
-        PyList::type_object_raw(py),
-        &mut result_list,
-    )?;
-    vec_into_pylist(py, &result_list)
+    let py_list_type = PyList::type_object_raw(py);
+    let node_ptr = node.as_ptr();
+    with_field_table(py, |table| {
+        let mut result_list = Vec::new();
+        walk_node_unordered(node_ptr, base, py_list_type, table, &mut result_list)?;
+        vec_into_pylist(py, &result_list)
+    })
 }
 
 static DEPRECATED_WALK_WARNED: AtomicBool = AtomicBool::new(false);
@@ -377,15 +541,14 @@ fn walk<'py>(py: Python<'py>, node: Bound<'py,
PyAny>) -> PyResult(py: Python, node: Bound<'py, PyAny>) -> PyResult { - let mut result_list = Vec::new(); let base = resolve_base_types(py)?; - walk_node_dfs( - node.as_ptr(), - base, - PyList::type_object_raw(py), - &mut result_list, - )?; - Ok(result_list.len()) + let py_list_type = PyList::type_object_raw(py); + let node_ptr = node.as_ptr(); + with_field_table(py, |table| { + let mut result_list = Vec::new(); + walk_node_dfs(node_ptr, base, py_list_type, table, &mut result_list)?; + Ok(result_list.len()) + }) } #[pymodule] @@ -410,7 +573,8 @@ mod tests { Python::attach(|py| { let dict = PyDict::new(py); let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject; - let values = unsafe { ReverseDictValuesIter::new(dict_ptr) }.collect::>(); + let values = + unsafe { ReverseDictValuesIter::new(dict_ptr, usize::MAX) }.collect::>(); assert_eq!(values.len(), 0); }); } @@ -426,7 +590,8 @@ mod tests { dict.set_item("c", 3).unwrap(); let dict_ptr = dict.as_ptr() as *mut pyo3::ffi::PyDictObject; - let values = unsafe { ReverseDictValuesIter::new(dict_ptr) }.collect::>(); + let values = + unsafe { ReverseDictValuesIter::new(dict_ptr, usize::MAX) }.collect::>(); assert_eq!(values.len(), 3); }); } diff --git a/tests/test_coherency.py b/tests/test_coherency.py index 6f27fed..a29aed4 100644 --- a/tests/test_coherency.py +++ b/tests/test_coherency.py @@ -113,6 +113,18 @@ def __init__(self, value: T) -> None: type Alias[T] = list[T] """), + # Wider / deeper sources to exercise traversal on bigger trees. The + # smaller sources above catch most syntactic edge cases; these ones + # are here to surface scale-sensitive bugs — unbounded stack growth, + # quadratic behaviour, cycle-like regressions — that only appear once + # the node count is large enough. 
+ "many_small_statements": "\n".join(f"x{i} = {i}" for i in range(500)), + "wide_function_body": textwrap.dedent(""" + def big(): + """) + "\n".join(f" y{i} = f({i}) + g({i})" for i in range(300)), + "deeply_nested_calls": "f(" * 100 + "x" + ")" * 100, + "deeply_nested_attributes": "a" + ".b" * 200, + "deeply_nested_binops": " + ".join(f"x{i}" for i in range(300)), } @@ -249,3 +261,61 @@ def test_matches_stdlib_for_sizable_real_file(): expected = _multiset(ast.walk(tree)) assert _multiset(walk_dfs(tree)) == expected assert _multiset(walk_unordered(tree)) == expected + + +@pytest.mark.parametrize("walk_fn", [walk_dfs, walk_unordered]) +def test_parent_back_references_do_not_inflate_walk(walk_fn, tree: ast.AST): + """Decorating nodes with `.parent` back-references is a common AST- + transformer pattern. `ast.walk` never follows non-`_fields` keys, so + the back-refs are invisible to it — fast_walk must match. Under the + old "scan every __dict__ value" implementation this would either + infinite-loop or silently multiply the result; today it's guarded by + the per-type `_fields`-length bound.""" + for node in ast.walk(tree): + for child in ast.iter_child_nodes(node): + child.parent = node # pyright: ignore[reportAttributeAccessIssue] + + assert _multiset(walk_fn(tree)) == _multiset(ast.walk(tree)) + + +@pytest.mark.parametrize("walk_fn", [walk_dfs, walk_unordered]) +def test_self_reference_does_not_inflate_walk(walk_fn): + """Minimal cycle: a node attaches a reference to itself via a non- + `_fields` attribute. 
Must terminate and not double-count.""" + tree = ast.parse("x = 1") + tree.body[0].self_ref = tree.body[0] # pyright: ignore[reportAttributeAccessIssue] + + assert _multiset(walk_fn(tree)) == _multiset(ast.walk(tree)) + + +@pytest.mark.parametrize("walk_fn", [walk_dfs, walk_unordered]) +def test_non_fields_ast_attribute_is_ignored(walk_fn): + """A user-attached AST reference outside `_fields` must not leak + into the walk result — not even when there's no cycle. Keeps us + strictly equivalent to `ast.walk`.""" + tree = ast.parse("x = 1") + sibling = ast.parse("y = 2").body[0] + tree.body[0].extra = sibling # pyright: ignore[reportAttributeAccessIssue] + + result = walk_fn(tree) + assert _multiset(result) == _multiset(ast.walk(tree)) + assert id(sibling) not in {id(n) for n in result} + + +def test_parent_back_references_on_real_file(): + """Scale check for the cycle fix: every node in a real module gets + a `.parent` back-reference, and both walks must still match + `ast.walk` exactly. Under a regression this would exhaust memory.""" + import difflib + from pathlib import Path + + tree = ast.parse(Path(difflib.__file__).read_text()) + for node in ast.walk(tree): + for child in ast.iter_child_nodes(node): + child.parent = node # pyright: ignore[reportAttributeAccessIssue] + + expected = _multiset(ast.walk(tree)) + assert _multiset(walk_dfs(tree)) == expected + assert _multiset(walk_unordered(tree)) == expected + + diff --git a/tests/test_refcount.py b/tests/test_refcount.py index 8526e60..dfaa048 100644 --- a/tests/test_refcount.py +++ b/tests/test_refcount.py @@ -40,6 +40,10 @@ def test_refcount_neutral_after_walks(): tree = ast.parse(SOURCE) sample = list(ast.walk(tree)) + # Drain any prior-test garbage (shared AST singletons like ast.Load() + # have refcounts that reflect every live AST tree in the process, + # so leftover trees from earlier tests would inflate `before`). 
+ gc.collect() before = [sys.getrefcount(n) for n in sample] for _ in range(1000): fast_walk(tree) # result is dropped immediately as an expr statement