diff --git a/python/mlx/utils.py b/python/mlx/utils.py
index 540e81b049..265df94e8c 100644
--- a/python/mlx/utils.py
+++ b/python/mlx/utils.py
@@ -4,6 +4,8 @@ from itertools import zip_longest
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from mlx.core import register_pytree_node  # noqa: F401 (public re-export)
+
 
 def tree_map(
     fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index cb031cf78c..b2e56d7f4b 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -15,6 +15,7 @@ void init_metal(nb::module_&);
 void init_cuda(nb::module_&);
 void init_memory(nb::module_&);
 void init_ops(nb::module_&);
+void init_trees(nb::module_&);
 void init_transforms(nb::module_&);
 void init_random(nb::module_&);
 void init_fft(nb::module_&);
@@ -39,6 +40,7 @@ NB_MODULE(core, m) {
   init_cuda(m);
   init_memory(m);
   init_ops(m);
+  init_trees(m);
   init_transforms(m);
   init_random(m);
   init_fft(m);
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index d0d5fe89b7..d29a6f3fd0 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -461,6 +461,7 @@ struct PyCompiledFun {
     constexpr uint64_t list_identifier = 18446744073709551533UL;
     constexpr uint64_t dict_identifier = 18446744073709551521UL;
     constexpr uint64_t none_identifier = 10239356951478402889UL;
+    constexpr uint64_t pytree_identifier = 14695981039346656037UL;
 
     // Flatten the tree with hashed constants and structure
     std::function<void(nb::handle)> recurse;
@@ -488,6 +489,17 @@ struct PyCompiledFun {
       } else if (nb::isinstance<mx::array>(obj)) {
         inputs.push_back(nb::cast<mx::array>(obj));
         constants.push_back(array_identifier);
+      } else if (is_registered_pytree(obj)) {
+        // Custom registered pytree node — treat as an internal node. The
+        // type identity + aux fingerprint participate in the compile cache
+        // key so two structurally-different instances retrace.
+        constants.push_back(pytree_identifier);
+        constants.push_back(registered_pytree_fingerprint(obj));
+        auto children = pytree_children(obj);
+        constants.push_back(static_cast<uint64_t>(children.size()));
+        for (const auto& child : children) {
+          recurse(child);
+        }
       } else if (nb::isinstance<nb::str>(obj)) {
         auto r = obj.attr("__hash__")();
         constants.push_back(nb::cast<int64_t>(r));
@@ -502,7 +514,8 @@ struct PyCompiledFun {
         std::ostringstream msg;
         msg << "[compile] Function arguments must be trees of arrays "
             << "or constants (floats, ints, strings, or None), but received "
-            << "type " << type_name_str(obj) << ".";
+            << "type " << type_name_str(obj) << ". To pass a custom type, "
+            << "register it with mx.register_pytree_node().";
         throw std::invalid_argument(msg.str());
       }
     };
diff --git a/python/src/trees.cpp b/python/src/trees.cpp
index 4b9ca9e123..9f08136c8e 100644
--- a/python/src/trees.cpp
+++ b/python/src/trees.cpp
@@ -1,7 +1,159 @@
 // Copyright © 2023-2024 Apple Inc.
+#include <Python.h>
+
+#include <cstdint>
+#include <unordered_map>
+
 #include "python/src/trees.h"
 
+namespace {
+
+struct PytreeNodeDef {
+  nb::callable flatten_fn;
+  nb::callable unflatten_fn;
+};
+
+// The map is intentionally heap-allocated and never freed. Holding
+// nb::callable references in a function-local static triggers a
+// use-after-finalize when the C++ runtime tears down the static during
+// interpreter shutdown — the Python state is already gone, so decrefing the
+// stored callables segfaults. This is the same lifetime trick used by
+// structure_sentinel() below.
+std::unordered_map<PyObject*, PytreeNodeDef>& registry() {
+  static auto* r = new std::unordered_map<PyObject*, PytreeNodeDef>();
+  return *r;
+}
+
+// Calls the registered flatten_fn for obj and returns (children, aux).
+std::pair<std::vector<nb::object>, nb::object> flatten_registered(
+    nb::handle obj) {
+  auto& def = registry().at(reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr())));
+  auto seq = nb::cast<nb::sequence>(def.flatten_fn(obj));
+  if (nb::len(seq) != 2) {
+    throw std::invalid_argument(
+        "[register_pytree_node] flatten_fn must return a (children, aux_data) "
+        "pair.");
+  }
+  auto children = nb::cast<std::vector<nb::object>>(seq[0]);
+  return {std::move(children), nb::cast<nb::object>(seq[1])};
+}
+
+// Recreates the original object from aux + children for the type of `like`.
+nb::object unflatten_registered(
+    nb::handle like,
+    nb::object aux,
+    const std::vector<nb::object>& children) {
+  auto& def = registry().at(reinterpret_cast<PyObject*>(Py_TYPE(like.ptr())));
+  nb::list children_list;
+  for (const auto& c : children) {
+    children_list.append(c);
+  }
+  return def.unflatten_fn(aux, children_list);
+}
+
+} // namespace
+
+void register_pytree_node(
+    nb::type_object cls,
+    nb::callable flatten_fn,
+    nb::callable unflatten_fn) {
+  registry()[cls.ptr()] = PytreeNodeDef{flatten_fn, unflatten_fn};
+}
+
+bool is_registered_pytree(nb::handle obj) {
+  if (!obj.ptr()) {
+    return false;
+  }
+  return registry().find(reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr()))) !=
+      registry().end();
+}
+
+std::vector<nb::object> pytree_children(nb::handle obj) {
+  return flatten_registered(obj).first;
+}
+
+uint64_t registered_pytree_fingerprint(nb::handle obj) {
+  PyObject* type = reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr()));
+  uint64_t fp = reinterpret_cast<uint64_t>(type);
+
+  // Mix in hash(aux_data) so structurally distinct registered nodes don't
+  // collide. We re-call flatten_fn purely to retrieve aux; this is the same
+  // cost as the structural recurse below and keeps the fingerprint in sync
+  // with how the node will be expanded.
+  auto it = registry().find(type);
+  if (it != registry().end()) {
+    try {
+      auto seq = nb::cast<nb::sequence>(it->second.flatten_fn(obj));
+      if (nb::len(seq) == 2) {
+        nb::object aux = nb::cast<nb::object>(seq[1]);
+        if (!aux.is_none()) {
+          try {
+            uint64_t aux_hash = static_cast<uint64_t>(nb::hash(aux));
+            fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2);
+          } catch (...) {
+            // Unhashable aux — fall back to type-only fingerprint.
+          }
+        }
+      }
+    } catch (...) {
+      // flatten_fn failed — fall back to type-only fingerprint.
+    }
+  }
+  return fp;
+}
+
+void init_trees(nb::module_& m) {
+  m.def(
+      "register_pytree_node",
+      &register_pytree_node,
+      nb::arg("cls"),
+      nb::arg("flatten_fn"),
+      nb::arg("unflatten_fn"),
+      R"pbdoc(
+        Register a custom class as a pytree node.
+
+        Once registered, instances of ``cls`` are treated as interior nodes
+        (not leaves) by :func:`mlx.core.compile`, :func:`mlx.utils.tree_map`,
+        :func:`mlx.utils.tree_flatten`, and friends.
+
+        Args:
+            cls (type): The class to register.
+            flatten_fn (callable): ``flatten_fn(obj) -> (children, aux_data)``
+                where *children* is a list or tuple of sub-trees (may contain
+                :class:`array` or further nested structures) and *aux_data* is
+                any hashable metadata needed to reconstruct the object. When
+                a registered object appears as a :func:`compile` argument the
+                hash of *aux_data* participates in the compile cache key, so
+                two instances with different aux trigger a retrace.
+            unflatten_fn (callable): ``unflatten_fn(aux_data, children) -> obj``
+                that recreates the original object from *aux_data* and the
+                (possibly updated) *children* list.
+
+        Example:
+
+            >>> import mlx.core as mx
+            >>>
+            >>> class Pair:
+            ...     def __init__(self, a, b):
+            ...         self.a = a
+            ...         self.b = b
+            ...
+            >>> mx.register_pytree_node(
+            ...     Pair,
+            ...     lambda p: ([p.a, p.b], None),
+            ...     lambda _, children: Pair(*children),
+            ... )
+            >>>
+            >>> @mx.compile
+            ... def add_pair(p):
+            ...     return p.a + p.b
+            ...
+            >>> add_pair(Pair(mx.array(1), mx.array(2)))
+            array(3, dtype=int32)
+      )pbdoc");
+}
+
 template <typename T, typename U, typename V>
 void validate_subtrees(const std::vector<nb::object>& subtrees) {
   int len = nb::cast<T>(subtrees[0]).size();
@@ -76,6 +228,39 @@ nb::object tree_map(
         d[item.first] = recurse(items);
       }
       return nb::cast<nb::object>(d);
+    } else if (is_registered_pytree(subtrees[0])) {
+      auto [children, aux] = flatten_registered(subtrees[0]);
+      PyTypeObject* type = Py_TYPE(subtrees[0].ptr());
+
+      // Pre-flatten every other subtree so we can index parallel children.
+      std::vector<std::vector<nb::object>> other_children(subtrees.size());
+      other_children[0] = std::move(children);
+      for (size_t j = 1; j < subtrees.size(); ++j) {
+        if (is_registered_pytree(subtrees[j]) &&
+            Py_TYPE(subtrees[j].ptr()) == type) {
+          other_children[j] = flatten_registered(subtrees[j]).first;
+          if (other_children[j].size() != other_children[0].size()) {
+            throw std::invalid_argument(
+                "[tree_map] Additional input tree is not a valid prefix of "
+                "the first tree.");
+          }
+        }
+      }
+
+      std::vector<nb::object> new_children;
+      new_children.reserve(other_children[0].size());
+      for (size_t i = 0; i < other_children[0].size(); ++i) {
+        std::vector<nb::object> items(subtrees.size());
+        for (size_t j = 0; j < subtrees.size(); ++j) {
+          if (!other_children[j].empty()) {
+            items[j] = other_children[j][i];
+          } else {
+            items[j] = subtrees[j];
+          }
+        }
+        new_children.push_back(recurse(items));
+      }
+      return unflatten_registered(subtrees[0], aux, new_children);
     } else {
       return transform(subtrees);
     }
@@ -143,6 +328,32 @@ void tree_visit(
        }
        recurse(items);
      }
+    } else if (is_registered_pytree(subtrees[0])) {
+      PyTypeObject* type = Py_TYPE(subtrees[0].ptr());
+      std::vector<std::vector<nb::object>> other_children(subtrees.size());
+      other_children[0] = flatten_registered(subtrees[0]).first;
+      for (size_t j = 1; j < subtrees.size(); ++j) {
+        if (is_registered_pytree(subtrees[j]) &&
+            Py_TYPE(subtrees[j].ptr()) == type) {
+          other_children[j] = flatten_registered(subtrees[j]).first;
+          if (other_children[j].size() != other_children[0].size()) {
+            throw std::invalid_argument(
+                "[tree_visit] Additional input tree is not a valid prefix of "
+                "the first tree.");
+          }
+        }
+      }
+      for (size_t i = 0; i < other_children[0].size(); ++i) {
+        std::vector<nb::object> items(subtrees.size());
+        for (size_t j = 0; j < subtrees.size(); ++j) {
+          if (!other_children[j].empty()) {
+            items[j] = other_children[j][i];
+          } else {
+            items[j] = subtrees[j];
+          }
+        }
+        recurse(items);
+      }
     } else {
       visitor(subtrees);
     }
@@ -162,6 +373,11 @@ void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
       for (auto item : nb::cast<nb::dict>(subtree)) {
        recurse(item.second);
       }
+    } else if (is_registered_pytree(subtree)) {
+      auto [children, _] = flatten_registered(subtree);
+      for (const auto& child : children) {
+        recurse(child);
+      }
     } else {
       visitor(subtree);
     }
@@ -197,6 +413,14 @@ void tree_visit_update(
        d[item.first] = recurse(item.second);
       }
       return nb::cast<nb::object>(d);
+    } else if (is_registered_pytree(subtree)) {
+      auto [children, aux] = flatten_registered(subtree);
+      std::vector<nb::object> new_children;
+      new_children.reserve(children.size());
+      for (auto& c : children) {
+        new_children.push_back(recurse(c));
+      }
+      return unflatten_registered(subtree, aux, new_children);
     } else if (nb::isinstance<mx::array>(subtree)) {
       return visitor(subtree);
     } else {
diff --git a/python/src/trees.h b/python/src/trees.h
index 3faa3e39ce..8b60252662 100644
--- a/python/src/trees.h
+++ b/python/src/trees.h
@@ -1,12 +1,40 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once
 
 #include <nanobind/nanobind.h>
+#include <cstdint>
 
 #include "mlx/array.h"
 
 namespace mx = mlx::core;
 namespace nb = nanobind;
 
+// --------------------------------------------------------------------------
+// Pytree node registry
+//
+// Allows third-party Python classes to participate in MLX tree utilities
+// and in mx.compile argument flattening. Mirrors the API of
+// jax.tree_util.register_pytree_node:
+//
+//   flatten_fn(obj) -> (children: Sequence, aux_data: Any)
+//   unflatten_fn(aux, children) -> obj
+// --------------------------------------------------------------------------
+
+void register_pytree_node(
+    nb::type_object cls,
+    nb::callable flatten_fn,
+    nb::callable unflatten_fn);
+
+// True if type(obj) has been registered as a pytree node.
+bool is_registered_pytree(nb::handle obj);
+
+// Returns the children list for a registered pytree node. Caller must ensure
+// is_registered_pytree(obj) is true.
+std::vector<nb::object> pytree_children(nb::handle obj);
+
+// Compile cache fingerprint for a registered pytree's type+aux pair.
+// Combines id(type) and hash(aux) so that compile retraces if either changes.
+uint64_t registered_pytree_fingerprint(nb::handle obj);
+
 void tree_visit(
     const std::vector<nb::object>& trees,
     std::function<void(const std::vector<nb::object>&)> visitor);
diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py
index 20f1145223..3e4ef16f04 100644
--- a/python/tests/test_compile.py
+++ b/python/tests/test_compile.py
@@ -1322,6 +1322,67 @@ def fun(x):
             np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
         )
 
+    def test_compile_registered_pytree_node(self):
+        # Custom container class that holds two arrays.
+        class Pair:
+            def __init__(self, a, b):
+                self.a = a
+                self.b = b
+
+        # Before registration, compile must reject Pair instances.
+        before = mx.compile(lambda p: p.a + p.b)
+        with self.assertRaises(ValueError):
+            before(Pair(mx.array(1), mx.array(2)))
+
+        # Register Pair and verify the compiled function works end-to-end.
+        mx.register_pytree_node(
+            Pair,
+            lambda p: ([p.a, p.b], None),
+            lambda _aux, children: Pair(*children),
+        )
+
+        @mx.compile
+        def add_pair(p):
+            return p.a + p.b
+
+        out = add_pair(Pair(mx.array(3), mx.array(4)))
+        self.assertEqual(out.item(), 7)
+
+        # Aux-data participates in the cache key: a subclass with different
+        # aux-tag is treated as a distinct shape.
+        class TaggedPair(Pair):
+            pass
+
+        mx.register_pytree_node(
+            TaggedPair,
+            lambda p: ([p.a, p.b], "tag-v1"),
+            lambda _aux, children: TaggedPair(*children),
+        )
+
+        @mx.compile
+        def doubled(p):
+            return (p.a + p.b) * 2
+
+        self.assertEqual(doubled(Pair(mx.array(5), mx.array(6))).item(), 22)
+        self.assertEqual(doubled(TaggedPair(mx.array(5), mx.array(6))).item(), 22)
+
+        # Bad flatten_fn return value surfaces a clean error.
+        class BadFlatten:
+            pass
+
+        mx.register_pytree_node(
+            BadFlatten,
+            lambda _: "not a pair",
+            lambda _aux, children: BadFlatten(),
+        )
+
+        @mx.compile
+        def use_bad(_p):
+            return mx.array(0)
+
+        with self.assertRaises(ValueError):
+            use_bad(BadFlatten())
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()