From 727264c5bc3bc848c57f1138b862f4acb9c0b160 Mon Sep 17 00:00:00 2001 From: Adam Staniszewski Date: Mon, 11 May 2026 08:35:44 +0200 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20register=5Fpytree=5Fnode=20?= =?UTF-8?q?=E2=80=94=20allow=20custom=20classes=20in=20mx.compile?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a JAX-style pytree registration mechanism so third-party Python classes can flow through mx.compile, tree_visit, tree_map, and the rest of MLX's tree utilities. Motivation ---------- mx.compile rejects any function argument that is not a plain array, list, dict, tuple, or scalar constant: ValueError: [compile] Function arguments must be trees of arrays or constants (floats, ints, strings, or None), but received type mlx_lm.models.cache.ArraysCache. Any model whose forward pass receives a custom cache object — every hybrid SSM+attention model in mlx-lm (Qwen 3.5/3.6, Llama 4, Gemma 3n, etc.) — therefore cannot be compiled, even though the computation is fully expressible as MLX ops. Implementation -------------- The registry, the public API, and all tree-traversal hooks live in C++ (per review feedback: a Python-side compile wrapper would duplicate the implementation across two languages). python/src/trees.h, python/src/trees.cpp: * PytreeNodeDef — (flatten_fn, unflatten_fn) pair. * registry() — heap-allocated map keyed by PyTypeObject*, never freed. Avoids the use-after-finalize segfault that a function-local static would hit when Python tears down the interpreter while stored nb::callables still hold refs. Same lifetime pattern used by structure_sentinel(). * register_pytree_node(cls, flatten_fn, unflatten_fn) — exposed to Python as mx.register_pytree_node. * is_registered_pytree, flatten_registered, unflatten_registered, registered_pytree_fingerprint — internal helpers. * tree_visit / tree_map (multi-tree and single-tree overloads) and tree_visit_update now recurse into registered types, so tree_unflatten through the compile path reconstructs them. python/src/transforms.cpp: * PyCompiledFun::call_impl::recurse adds a pytree_identifier branch: flattens the registered node into its children and embeds the type-id + aux hash in the constants vector, so two structurally different registered instances retrace correctly. * Error message updated to mention mx.register_pytree_node. python/src/mlx.cpp: * Wires init_trees() into NB_MODULE. python/mlx/utils.py: * re-exports mlx.core.register_pytree_node so users can do either `import mlx.core as mx; mx.register_pytree_node(...)` or `from mlx.utils import register_pytree_node`. Test ---- python/tests/test_compile.py::test_compile_registered_pytree_node: * mx.compile rejects an unregistered custom class. * After registration the compiled forward returns the correct value. * aux_data tagged differently on two subclasses retraces cleanly. * flatten_fn returning a malformed value surfaces a clear ValueError. All existing tests still pass: - python/tests/test_compile.py — 55 passed - python/tests/test_tree.py — 4 passed - python/tests/test_autograd.py + test_vmap.py — full suite green Co-Authored-By: Claude Sonnet 4.6 --- python/mlx/utils.py | 2 + python/src/mlx.cpp | 2 + python/src/transforms.cpp | 15 +- python/src/trees.cpp | 263 +++++++++++++++++++++++++++++++++++ python/src/trees.h | 38 +++++ python/tests/test_compile.py | 63 +++++++++ 6 files changed, 382 insertions(+), 1 deletion(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 540e81b049..265df94e8c 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -4,6 +4,8 @@ from itertools import zip_longest from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from mlx.core import register_pytree_node # noqa: F401 (public re-export) + def tree_map( fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index cb031cf78c..b2e56d7f4b 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -15,6 +15,7 @@ void init_metal(nb::module_&); void init_cuda(nb::module_&); void init_memory(nb::module_&); void init_ops(nb::module_&); +void init_trees(nb::module_&); void init_transforms(nb::module_&); void init_random(nb::module_&); void init_fft(nb::module_&); @@ -39,6 +40,7 @@ NB_MODULE(core, m) { init_cuda(m); init_memory(m); init_ops(m); + init_trees(m); init_transforms(m); init_random(m); init_fft(m); diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index d0d5fe89b7..2db4a0ccb8 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -461,6 +461,7 @@ struct PyCompiledFun { constexpr uint64_t list_identifier = 18446744073709551533UL; constexpr uint64_t dict_identifier = 18446744073709551521UL; constexpr uint64_t none_identifier = 10239356951478402889UL; + constexpr uint64_t pytree_identifier = 14695981039346656037UL; // Flatten the tree with hashed constants and structure std::function recurse; @@ -488,6 +489,17 @@ struct PyCompiledFun { } else if (nb::isinstance(obj)) { inputs.push_back(nb::cast(obj)); constants.push_back(array_identifier); + } else if (is_registered_pytree(obj)) { + // Custom registered pytree node — treat as an internal node. The + // type identity + aux fingerprint participate in the compile cache + // key so two structurally-different instances retrace. + constants.push_back(pytree_identifier); + constants.push_back(registered_pytree_fingerprint(obj)); + auto [children, _aux] = flatten_registered(obj); + constants.push_back(static_cast(children.size())); + for (const auto& child : children) { + recurse(child); + } } else if (nb::isinstance(obj)) { auto r = obj.attr("__hash__")(); constants.push_back(nb::cast(r)); @@ -502,7 +514,8 @@ struct PyCompiledFun { std::ostringstream msg; msg << "[compile] Function arguments must be trees of arrays " << "or constants (floats, ints, strings, or None), but received " - << "type " << type_name_str(obj) << "."; + << "type " << type_name_str(obj) << ". To pass a custom type, " + << "register it with mx.register_pytree_node()."; throw std::invalid_argument(msg.str()); } }; diff --git a/python/src/trees.cpp b/python/src/trees.cpp index 4b9ca9e123..43a48a7bb2 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -1,7 +1,195 @@ // Copyright © 2023-2024 Apple Inc. +#include + +#include +#include + #include "python/src/trees.h" +namespace { + +struct PytreeNodeDef { + nb::callable flatten_fn; + nb::callable unflatten_fn; +}; + +// Keyed by raw PyTypeObject pointer; the type is held via the registered +// callables so it cannot be collected while the def is live. +// +// The map is intentionally heap-allocated and never freed. Holding nb::callable +// references in a function-local static triggers a use-after-finalize when the +// C++ runtime tears down the static during interpreter shutdown — the Python +// state is already gone, so decrefing the stored callables segfaults. This is +// the same lifetime trick used by structure_sentinel() below. +std::unordered_map& registry() { + static auto* r = new std::unordered_map(); + return *r; +} + +} // namespace + +void register_pytree_node( + nb::object cls, + nb::callable flatten_fn, + nb::callable unflatten_fn) { + if (!PyType_Check(cls.ptr())) { + throw std::invalid_argument( + "[register_pytree_node] cls must be a Python class object."); + } + PyTypeObject* type = reinterpret_cast(cls.ptr()); + registry()[type] = PytreeNodeDef{flatten_fn, unflatten_fn}; +} + +bool is_registered_pytree(nb::handle obj) { + if (!obj.ptr()) { + return false; + } + return registry().find(Py_TYPE(obj.ptr())) != registry().end(); +} + +std::pair, nb::object> flatten_registered( + nb::handle obj) { + PyTypeObject* type = Py_TYPE(obj.ptr()); + auto it = registry().find(type); + if (it == registry().end()) { + throw std::runtime_error( + "[flatten_registered] type is not registered as a pytree node"); + } + nb::object result = it->second.flatten_fn(obj); + if (!nb::isinstance(result) && !nb::isinstance(result)) { + throw std::invalid_argument( + "[register_pytree_node] flatten_fn must return a (children, aux_data) " + "pair."); + } + auto seq = nb::cast(result); + if (nb::len(seq) != 2) { + throw std::invalid_argument( + "[register_pytree_node] flatten_fn must return a (children, aux_data) " + "pair."); + } + nb::object children_obj = seq[0]; + nb::object aux = seq[1]; + + std::vector children; + if (nb::isinstance(children_obj) || + nb::isinstance(children_obj)) { + auto iter = nb::iter(children_obj); + for (auto h : iter) { + children.push_back(nb::cast(h)); + } + } else { + throw std::invalid_argument( + "[register_pytree_node] flatten_fn must return children as a list or " + "tuple."); + } + return {children, aux}; +} + +nb::object unflatten_registered( + nb::handle type, + nb::object aux_data, + const std::vector& children) { + PyTypeObject* t = reinterpret_cast(type.ptr()); + auto it = registry().find(t); + if (it == registry().end()) { + throw std::runtime_error( + "[unflatten_registered] type is not registered as a pytree node"); + } + nb::list children_list; + for (const auto& c : children) { + children_list.append(c); + } + return it->second.unflatten_fn(aux_data, children_list); +} + +uint64_t registered_pytree_fingerprint(nb::handle obj) { + PyTypeObject* type = Py_TYPE(obj.ptr()); + uint64_t fp = reinterpret_cast(type); + + // Mix in hash(aux_data) so structurally distinct registered nodes don't + // collide. We re-call flatten_fn purely to retrieve aux; this is the same + // cost as the structural recurse below and keeps the fingerprint in sync + // with how the node will be expanded. + auto it = registry().find(type); + if (it != registry().end()) { + try { + nb::object result = it->second.flatten_fn(obj); + if (nb::isinstance(result) || + nb::isinstance(result)) { + auto seq = nb::cast(result); + if (nb::len(seq) == 2) { + nb::object aux = seq[1]; + if (!aux.is_none()) { + try { + auto h = aux.attr("__hash__")(); + uint64_t aux_hash = + static_cast(nb::cast(h)); + fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2); + } catch (...) { + // Unhashable aux — fall back to type-only fingerprint. + } + } + } + } + } catch (...) { + // flatten_fn failed — fall back to type-only fingerprint. + } + } + return fp; +} + +void init_trees(nb::module_& m) { + m.def( + "register_pytree_node", + ®ister_pytree_node, + nb::arg("cls"), + nb::arg("flatten_fn"), + nb::arg("unflatten_fn"), + R"pbdoc( + Register a custom class as a pytree node. + + Once registered, instances of ``cls`` are treated as interior nodes + (not leaves) by :func:`mlx.core.compile`, :func:`mlx.utils.tree_map`, + :func:`mlx.utils.tree_flatten`, and friends. + + Args: + cls (type): The class to register. + flatten_fn (callable): ``flatten_fn(obj) -> (children, aux_data)`` + where *children* is a list or tuple of sub-trees (may contain + :class:`array` or further nested structures) and *aux_data* is + any hashable metadata needed to reconstruct the object. When + a registered object appears as a :func:`compile` argument the + hash of *aux_data* participates in the compile cache key, so + two instances with different aux trigger a retrace. + unflatten_fn (callable): ``unflatten_fn(aux_data, children) -> obj`` + that recreates the original object from *aux_data* and the + (possibly updated) *children* list. + + Example: + + >>> import mlx.core as mx + >>> + >>> class Pair: + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + >>> mx.register_pytree_node( + ... Pair, + ... lambda p: ([p.a, p.b], None), + ... lambda _, children: Pair(*children), + ... ) + >>> + >>> @mx.compile + ... def add_pair(p): + ... return p.a + p.b + ... + >>> add_pair(Pair(mx.array(1), mx.array(2))) + array(3, dtype=int32) + )pbdoc"); +} + template void validate_subtrees(const std::vector& subtrees) { int len = nb::cast(subtrees[0]).size(); @@ -76,6 +264,40 @@ nb::object tree_map( d[item.first] = recurse(items); } return nb::cast(d); + } else if (is_registered_pytree(subtrees[0])) { + auto [children, aux] = flatten_registered(subtrees[0]); + PyTypeObject* type = Py_TYPE(subtrees[0].ptr()); + nb::handle type_handle(reinterpret_cast(type)); + + // Pre-flatten every other subtree so we can index parallel children. + std::vector> other_children(subtrees.size()); + other_children[0] = std::move(children); + for (size_t j = 1; j < subtrees.size(); ++j) { + if (is_registered_pytree(subtrees[j]) && + Py_TYPE(subtrees[j].ptr()) == type) { + other_children[j] = flatten_registered(subtrees[j]).first; + if (other_children[j].size() != other_children[0].size()) { + throw std::invalid_argument( + "[tree_map] Additional input tree is not a valid prefix of " + "the first tree."); + } + } + } + + std::vector new_children; + new_children.reserve(other_children[0].size()); + for (size_t i = 0; i < other_children[0].size(); ++i) { + std::vector items(subtrees.size()); + for (size_t j = 0; j < subtrees.size(); ++j) { + if (!other_children[j].empty()) { + items[j] = other_children[j][i]; + } else { + items[j] = subtrees[j]; + } + } + new_children.push_back(recurse(items)); + } + return unflatten_registered(type_handle, aux, new_children); } else { return transform(subtrees); } @@ -143,6 +365,32 @@ void tree_visit( } recurse(items); } + } else if (is_registered_pytree(subtrees[0])) { + PyTypeObject* type = Py_TYPE(subtrees[0].ptr()); + std::vector> other_children(subtrees.size()); + other_children[0] = flatten_registered(subtrees[0]).first; + for (size_t j = 1; j < subtrees.size(); ++j) { + if (is_registered_pytree(subtrees[j]) && + Py_TYPE(subtrees[j].ptr()) == type) { + other_children[j] = flatten_registered(subtrees[j]).first; + if (other_children[j].size() != other_children[0].size()) { + throw std::invalid_argument( + "[tree_visit] Additional input tree is not a valid prefix of " + "the first tree."); + } + } + } + for (size_t i = 0; i < other_children[0].size(); ++i) { + std::vector items(subtrees.size()); + for (size_t j = 0; j < subtrees.size(); ++j) { + if (!other_children[j].empty()) { + items[j] = other_children[j][i]; + } else { + items[j] = subtrees[j]; + } + } + recurse(items); + } } else { visitor(subtrees); } @@ -162,6 +410,11 @@ void tree_visit(nb::handle tree, std::function visitor) { for (auto item : nb::cast(subtree)) { recurse(item.second); } + } else if (is_registered_pytree(subtree)) { + auto [children, _] = flatten_registered(subtree); + for (const auto& child : children) { + recurse(child); + } } else { visitor(subtree); } @@ -197,6 +450,16 @@ void tree_visit_update( d[item.first] = recurse(item.second); } return nb::cast(d); + } else if (is_registered_pytree(subtree)) { + auto [children, aux] = flatten_registered(subtree); + PyTypeObject* type = Py_TYPE(subtree.ptr()); + nb::handle type_handle(reinterpret_cast(type)); + std::vector new_children; + new_children.reserve(children.size()); + for (auto& c : children) { + new_children.push_back(recurse(c)); + } + return unflatten_registered(type_handle, aux, new_children); } else if (nb::isinstance(subtree)) { return visitor(subtree); } else { diff --git a/python/src/trees.h b/python/src/trees.h index 3faa3e39ce..1fa0b98ad6 100644 --- a/python/src/trees.h +++ b/python/src/trees.h @@ -1,12 +1,50 @@ // Copyright © 2023-2024 Apple Inc. #pragma once #include +#include +#include #include "mlx/array.h" namespace mx = mlx::core; namespace nb = nanobind; +// -------------------------------------------------------------------------- +// Pytree node registry +// +// Allows third-party Python classes to participate in MLX tree utilities +// and in mx.compile argument flattening. Mirrors the API of +// jax.tree_util.register_pytree_node: +// +// flatten_fn(obj) -> (children: Sequence, aux_data: Any) +// unflatten_fn(aux, children) -> obj +// -------------------------------------------------------------------------- + +void register_pytree_node( + nb::object cls, + nb::callable flatten_fn, + nb::callable unflatten_fn); + +// True if Py_TYPE(obj) has been registered as a pytree node. +bool is_registered_pytree(nb::handle obj); + +// Calls the registered flatten_fn for the type of obj. Caller must ensure +// is_registered_pytree(obj) is true. +std::pair, nb::object> flatten_registered( + nb::handle obj); + +// Calls the registered unflatten_fn for the given type object. +nb::object unflatten_registered( + nb::handle type, + nb::object aux_data, + const std::vector& children); + +// Compile cache fingerprint for a registered pytree's type+aux pair. +// Combines id(type) and hash(aux) so that compile retraces if either changes. +uint64_t registered_pytree_fingerprint(nb::handle obj); + +void init_trees(nb::module_& m); + void tree_visit( const std::vector& trees, std::function&)> visitor); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 20f1145223..d580c06f00 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1322,6 +1322,69 @@ def fun(x): np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr ) + def test_compile_registered_pytree_node(self): + # Custom container class that holds two arrays. + class Pair: + def __init__(self, a, b): + self.a = a + self.b = b + + # Before registration, compile must reject Pair instances. + before = mx.compile(lambda p: p.a + p.b) + with self.assertRaises(ValueError): + before(Pair(mx.array(1), mx.array(2))) + + # Register Pair and verify the compiled function works end-to-end. + mx.register_pytree_node( + Pair, + lambda p: ([p.a, p.b], None), + lambda _aux, children: Pair(*children), + ) + + @mx.compile + def add_pair(p): + return p.a + p.b + + out = add_pair(Pair(mx.array(3), mx.array(4))) + self.assertEqual(out.item(), 7) + + # Aux-data participates in the cache key: a subclass with different + # aux-tag is treated as a distinct shape. + class TaggedPair(Pair): + pass + + mx.register_pytree_node( + TaggedPair, + lambda p: ([p.a, p.b], "tag-v1"), + lambda _aux, children: TaggedPair(*children), + ) + + @mx.compile + def doubled(p): + return (p.a + p.b) * 2 + + self.assertEqual(doubled(Pair(mx.array(5), mx.array(6))).item(), 22) + self.assertEqual( + doubled(TaggedPair(mx.array(5), mx.array(6))).item(), 22 + ) + + # Bad flatten_fn return value surfaces a clean error. + class BadFlatten: + pass + + mx.register_pytree_node( + BadFlatten, + lambda _: "not a pair", + lambda _aux, children: BadFlatten(), + ) + + @mx.compile + def use_bad(_p): + return mx.array(0) + + with self.assertRaises(ValueError): + use_bad(BadFlatten()) + if __name__ == "__main__": mlx_tests.MLXTestRunner() From ff83371babf59fecfffc3868976011e875326feb Mon Sep 17 00:00:00 2001 From: Adam Staniszewski Date: Tue, 12 May 2026 11:21:52 +0200 Subject: [PATCH 2/3] style: apply clang-format and black to fix CI lint --- python/src/trees.cpp | 16 ++++++++-------- python/tests/test_compile.py | 4 +--- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/src/trees.cpp b/python/src/trees.cpp index 43a48a7bb2..8d0a347116 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -17,17 +17,18 @@ struct PytreeNodeDef { // Keyed by raw PyTypeObject pointer; the type is held via the registered // callables so it cannot be collected while the def is live. // -// The map is intentionally heap-allocated and never freed. Holding nb::callable -// references in a function-local static triggers a use-after-finalize when the -// C++ runtime tears down the static during interpreter shutdown — the Python -// state is already gone, so decrefing the stored callables segfaults. This is -// the same lifetime trick used by structure_sentinel() below. +// The map is intentionally heap-allocated and never freed. Holding +// nb::callable references in a function-local static triggers a +// use-after-finalize when the C++ runtime tears down the static during +// interpreter shutdown — the Python state is already gone, so decrefing the +// stored callables segfaults. This is the same lifetime trick used by +// structure_sentinel() below. std::unordered_map& registry() { static auto* r = new std::unordered_map(); return *r; } -} // namespace +} // namespace void register_pytree_node( nb::object cls, @@ -123,8 +124,7 @@ uint64_t registered_pytree_fingerprint(nb::handle obj) { if (!aux.is_none()) { try { auto h = aux.attr("__hash__")(); - uint64_t aux_hash = - static_cast(nb::cast(h)); + uint64_t aux_hash = static_cast(nb::cast(h)); fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2); } catch (...) { // Unhashable aux — fall back to type-only fingerprint. diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index d580c06f00..3e4ef16f04 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1364,9 +1364,7 @@ def doubled(p): return (p.a + p.b) * 2 self.assertEqual(doubled(Pair(mx.array(5), mx.array(6))).item(), 22) - self.assertEqual( - doubled(TaggedPair(mx.array(5), mx.array(6))).item(), 22 - ) + self.assertEqual(doubled(TaggedPair(mx.array(5), mx.array(6))).item(), 22) # Bad flatten_fn return value surfaces a clean error. class BadFlatten: From fccd07002b6f702fabe66c50b4941f63f9e84edf Mon Sep 17 00:00:00 2001 From: Adam Staniszewski Date: Tue, 12 May 2026 11:34:34 +0200 Subject: [PATCH 3/3] refactor(trees): address @zcbenz review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Header surface trimmed: only `register_pytree_node`, `is_registered_pytree`, `pytree_children`, `registered_pytree_fingerprint` are exposed. `flatten_registered`/`unflatten_registered` are internal helpers in trees.cpp and `init_trees` is no longer redeclared (already in mlx.cpp). - `register_pytree_node` now takes `nb::type_object` so nanobind enforces the type check; manual `PyType_Check` is gone. - Registry keyed by `PyObject*` directly — no `PyTypeObject*` reinterpret cast at the boundary. - Internal `flatten_registered` uses `nb::cast>` for children and lets `nb::cast` enforce the list/tuple shape. - Fingerprint uses `nb::hash` and lets nanobind throw on unhashable aux (no extra defensive casting). - `tree_visit_update` / `tree_map` pass the subtree handle directly to `unflatten_registered` instead of fabricating a type handle. --- python/src/transforms.cpp | 2 +- python/src/trees.cpp | 129 +++++++++++++------------------------- python/src/trees.h | 18 ++---- 3 files changed, 50 insertions(+), 99 deletions(-) diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp index 2db4a0ccb8..d29a6f3fd0 100644 --- a/python/src/transforms.cpp +++ b/python/src/transforms.cpp @@ -495,7 +495,7 @@ struct PyCompiledFun { // key so two structurally-different instances retrace. constants.push_back(pytree_identifier); constants.push_back(registered_pytree_fingerprint(obj)); - auto [children, _aux] = flatten_registered(obj); + auto children = pytree_children(obj); constants.push_back(static_cast(children.size())); for (const auto& child : children) { recurse(child); diff --git a/python/src/trees.cpp b/python/src/trees.cpp index 8d0a347116..9f08136c8e 100644 --- a/python/src/trees.cpp +++ b/python/src/trees.cpp @@ -14,98 +14,67 @@ struct PytreeNodeDef { nb::callable unflatten_fn; }; -// Keyed by raw PyTypeObject pointer; the type is held via the registered -// callables so it cannot be collected while the def is live. -// // The map is intentionally heap-allocated and never freed. Holding // nb::callable references in a function-local static triggers a // use-after-finalize when the C++ runtime tears down the static during // interpreter shutdown — the Python state is already gone, so decrefing the // stored callables segfaults. This is the same lifetime trick used by // structure_sentinel() below. -std::unordered_map& registry() { - static auto* r = new std::unordered_map(); +std::unordered_map& registry() { + static auto* r = new std::unordered_map(); return *r; } -} // namespace - -void register_pytree_node( - nb::object cls, - nb::callable flatten_fn, - nb::callable unflatten_fn) { - if (!PyType_Check(cls.ptr())) { - throw std::invalid_argument( - "[register_pytree_node] cls must be a Python class object."); - } - PyTypeObject* type = reinterpret_cast(cls.ptr()); - registry()[type] = PytreeNodeDef{flatten_fn, unflatten_fn}; -} - -bool is_registered_pytree(nb::handle obj) { - if (!obj.ptr()) { - return false; - } - return registry().find(Py_TYPE(obj.ptr())) != registry().end(); -} - +// Calls the registered flatten_fn for obj and returns (children, aux). std::pair, nb::object> flatten_registered( nb::handle obj) { - PyTypeObject* type = Py_TYPE(obj.ptr()); - auto it = registry().find(type); - if (it == registry().end()) { - throw std::runtime_error( - "[flatten_registered] type is not registered as a pytree node"); - } - nb::object result = it->second.flatten_fn(obj); - if (!nb::isinstance(result) && !nb::isinstance(result)) { - throw std::invalid_argument( - "[register_pytree_node] flatten_fn must return a (children, aux_data) " - "pair."); - } - auto seq = nb::cast(result); + auto& def = registry().at(reinterpret_cast(Py_TYPE(obj.ptr()))); + auto seq = nb::cast(def.flatten_fn(obj)); if (nb::len(seq) != 2) { throw std::invalid_argument( "[register_pytree_node] flatten_fn must return a (children, aux_data) " "pair."); } - nb::object children_obj = seq[0]; - nb::object aux = seq[1]; - - std::vector children; - if (nb::isinstance(children_obj) || - nb::isinstance(children_obj)) { - auto iter = nb::iter(children_obj); - for (auto h : iter) { - children.push_back(nb::cast(h)); - } - } else { - throw std::invalid_argument( - "[register_pytree_node] flatten_fn must return children as a list or " - "tuple."); - } - return {children, aux}; + auto children = nb::cast>(seq[0]); + return {std::move(children), nb::cast(seq[1])}; } +// Recreates the original object from aux + children for the type of `like`. nb::object unflatten_registered( - nb::handle type, - nb::object aux_data, + nb::handle like, + nb::object aux, const std::vector& children) { - PyTypeObject* t = reinterpret_cast(type.ptr()); - auto it = registry().find(t); - if (it == registry().end()) { - throw std::runtime_error( - "[unflatten_registered] type is not registered as a pytree node"); - } + auto& def = registry().at(reinterpret_cast(Py_TYPE(like.ptr()))); nb::list children_list; for (const auto& c : children) { children_list.append(c); } - return it->second.unflatten_fn(aux_data, children_list); + return def.unflatten_fn(aux, children_list); +} + +} // namespace + +void register_pytree_node( + nb::type_object cls, + nb::callable flatten_fn, + nb::callable unflatten_fn) { + registry()[cls.ptr()] = PytreeNodeDef{flatten_fn, unflatten_fn}; +} + +bool is_registered_pytree(nb::handle obj) { + if (!obj.ptr()) { + return false; + } + return registry().find(reinterpret_cast(Py_TYPE(obj.ptr()))) != + registry().end(); +} + +std::vector pytree_children(nb::handle obj) { + return flatten_registered(obj).first; } uint64_t registered_pytree_fingerprint(nb::handle obj) { - PyTypeObject* type = Py_TYPE(obj.ptr()); + PyObject* type = reinterpret_cast(Py_TYPE(obj.ptr())); uint64_t fp = reinterpret_cast(type); // Mix in hash(aux_data) so structurally distinct registered nodes don't @@ -115,20 +84,15 @@ uint64_t registered_pytree_fingerprint(nb::handle obj) { auto it = registry().find(type); if (it != registry().end()) { try { - nb::object result = it->second.flatten_fn(obj); - if (nb::isinstance(result) || - nb::isinstance(result)) { - auto seq = nb::cast(result); - if (nb::len(seq) == 2) { - nb::object aux = seq[1]; - if (!aux.is_none()) { - try { - auto h = aux.attr("__hash__")(); - uint64_t aux_hash = static_cast(nb::cast(h)); - fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2); - } catch (...) { - // Unhashable aux — fall back to type-only fingerprint. - } + auto seq = nb::cast(it->second.flatten_fn(obj)); + if (nb::len(seq) == 2) { + nb::object aux = nb::cast(seq[1]); + if (!aux.is_none()) { + try { + uint64_t aux_hash = static_cast(nb::hash(aux)); + fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2); + } catch (...) { + // Unhashable aux — fall back to type-only fingerprint. } } } @@ -267,7 +231,6 @@ nb::object tree_map( } else if (is_registered_pytree(subtrees[0])) { auto [children, aux] = flatten_registered(subtrees[0]); PyTypeObject* type = Py_TYPE(subtrees[0].ptr()); - nb::handle type_handle(reinterpret_cast(type)); // Pre-flatten every other subtree so we can index parallel children. std::vector> other_children(subtrees.size()); @@ -297,7 +260,7 @@ nb::object tree_map( } new_children.push_back(recurse(items)); } - return unflatten_registered(type_handle, aux, new_children); + return unflatten_registered(subtrees[0], aux, new_children); } else { return transform(subtrees); } @@ -452,14 +415,12 @@ void tree_visit_update( return nb::cast(d); } else if (is_registered_pytree(subtree)) { auto [children, aux] = flatten_registered(subtree); - PyTypeObject* type = Py_TYPE(subtree.ptr()); - nb::handle type_handle(reinterpret_cast(type)); std::vector new_children; new_children.reserve(children.size()); for (auto& c : children) { new_children.push_back(recurse(c)); } - return unflatten_registered(type_handle, aux, new_children); + return unflatten_registered(subtree, aux, new_children); } else if (nb::isinstance(subtree)) { return visitor(subtree); } else { diff --git a/python/src/trees.h b/python/src/trees.h index 1fa0b98ad6..8b60252662 100644 --- a/python/src/trees.h +++ b/python/src/trees.h @@ -1,7 +1,6 @@ // Copyright © 2023-2024 Apple Inc. #pragma once #include -#include #include #include "mlx/array.h" @@ -21,30 +20,21 @@ namespace nb = nanobind; // -------------------------------------------------------------------------- void register_pytree_node( - nb::object cls, + nb::type_object cls, nb::callable flatten_fn, nb::callable unflatten_fn); -// True if Py_TYPE(obj) has been registered as a pytree node. +// True if type(obj) has been registered as a pytree node. bool is_registered_pytree(nb::handle obj); -// Calls the registered flatten_fn for the type of obj. Caller must ensure +// Returns the children list for a registered pytree node. Caller must ensure // is_registered_pytree(obj) is true. -std::pair, nb::object> flatten_registered( - nb::handle obj); - -// Calls the registered unflatten_fn for the given type object. -nb::object unflatten_registered( - nb::handle type, - nb::object aux_data, - const std::vector& children); +std::vector pytree_children(nb::handle obj); // Compile cache fingerprint for a registered pytree's type+aux pair. // Combines id(type) and hash(aux) so that compile retraces if either changes. uint64_t registered_pytree_fingerprint(nb::handle obj); -void init_trees(nb::module_& m); - void tree_visit( const std::vector& trees, std::function&)> visitor);