Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/mlx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from mlx.core import register_pytree_node # noqa: F401 (public re-export)


def tree_map(
fn: Callable, tree: Any, *rest: Any, is_leaf: Optional[Callable] = None
Expand Down
2 changes: 2 additions & 0 deletions python/src/mlx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ void init_metal(nb::module_&);
void init_cuda(nb::module_&);
void init_memory(nb::module_&);
void init_ops(nb::module_&);
void init_trees(nb::module_&);
void init_transforms(nb::module_&);
void init_random(nb::module_&);
void init_fft(nb::module_&);
Expand All @@ -39,6 +40,7 @@ NB_MODULE(core, m) {
init_cuda(m);
init_memory(m);
init_ops(m);
init_trees(m);
init_transforms(m);
init_random(m);
init_fft(m);
Expand Down
15 changes: 14 additions & 1 deletion python/src/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ struct PyCompiledFun {
constexpr uint64_t list_identifier = 18446744073709551533UL;
constexpr uint64_t dict_identifier = 18446744073709551521UL;
constexpr uint64_t none_identifier = 10239356951478402889UL;
constexpr uint64_t pytree_identifier = 14695981039346656037UL;

// Flatten the tree with hashed constants and structure
std::function<void(nb::handle)> recurse;
Expand Down Expand Up @@ -488,6 +489,17 @@ struct PyCompiledFun {
} else if (nb::isinstance<mx::array>(obj)) {
inputs.push_back(nb::cast<mx::array>(obj));
constants.push_back(array_identifier);
} else if (is_registered_pytree(obj)) {
// Custom registered pytree node — treat as an internal node. The
// type identity + aux fingerprint participate in the compile cache
// key so two structurally-different instances retrace.
constants.push_back(pytree_identifier);
constants.push_back(registered_pytree_fingerprint(obj));
auto children = pytree_children(obj);
constants.push_back(static_cast<uint64_t>(children.size()));
for (const auto& child : children) {
recurse(child);
}
} else if (nb::isinstance<nb::str>(obj)) {
auto r = obj.attr("__hash__")();
constants.push_back(nb::cast<int64_t>(r));
Expand All @@ -502,7 +514,8 @@ struct PyCompiledFun {
std::ostringstream msg;
msg << "[compile] Function arguments must be trees of arrays "
<< "or constants (floats, ints, strings, or None), but received "
<< "type " << type_name_str(obj) << ".";
<< "type " << type_name_str(obj) << ". To pass a custom type, "
<< "register it with mx.register_pytree_node().";
throw std::invalid_argument(msg.str());
}
};
Expand Down
224 changes: 224 additions & 0 deletions python/src/trees.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,159 @@
// Copyright © 2023-2024 Apple Inc.

#include <unordered_map>

#include <nanobind/stl/pair.h>
#include <nanobind/stl/vector.h>

#include "python/src/trees.h"

namespace {

struct PytreeNodeDef {
nb::callable flatten_fn;
nb::callable unflatten_fn;
};

// The map is intentionally heap-allocated and never freed. Holding
// nb::callable references in a function-local static triggers a
// use-after-finalize when the C++ runtime tears down the static during
// interpreter shutdown — the Python state is already gone, so decrefing the
// stored callables segfaults. This is the same lifetime trick used by
// structure_sentinel() below.
std::unordered_map<PyObject*, PytreeNodeDef>& registry() {
static auto* r = new std::unordered_map<PyObject*, PytreeNodeDef>();
return *r;
}

// Calls the registered flatten_fn for obj and returns (children, aux).
std::pair<std::vector<nb::object>, nb::object> flatten_registered(
nb::handle obj) {
auto& def = registry().at(reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr())));
auto seq = nb::cast<nb::sequence>(def.flatten_fn(obj));
if (nb::len(seq) != 2) {
throw std::invalid_argument(
"[register_pytree_node] flatten_fn must return a (children, aux_data) "
"pair.");
}
auto children = nb::cast<std::vector<nb::object>>(seq[0]);
return {std::move(children), nb::cast<nb::object>(seq[1])};
}

// Recreates the original object from aux + children for the type of `like`.
nb::object unflatten_registered(
nb::handle like,
nb::object aux,
const std::vector<nb::object>& children) {
auto& def = registry().at(reinterpret_cast<PyObject*>(Py_TYPE(like.ptr())));
nb::list children_list;
for (const auto& c : children) {
children_list.append(c);
}
return def.unflatten_fn(aux, children_list);
}

} // namespace

void register_pytree_node(
nb::type_object cls,
nb::callable flatten_fn,
nb::callable unflatten_fn) {
registry()[cls.ptr()] = PytreeNodeDef{flatten_fn, unflatten_fn};
}

bool is_registered_pytree(nb::handle obj) {
if (!obj.ptr()) {
return false;
}
return registry().find(reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr()))) !=
registry().end();
}

std::vector<nb::object> pytree_children(nb::handle obj) {
return flatten_registered(obj).first;
}

uint64_t registered_pytree_fingerprint(nb::handle obj) {
PyObject* type = reinterpret_cast<PyObject*>(Py_TYPE(obj.ptr()));
uint64_t fp = reinterpret_cast<uintptr_t>(type);

// Mix in hash(aux_data) so structurally distinct registered nodes don't
// collide. We re-call flatten_fn purely to retrieve aux; this is the same
// cost as the structural recurse below and keeps the fingerprint in sync
// with how the node will be expanded.
auto it = registry().find(type);
if (it != registry().end()) {
try {
auto seq = nb::cast<nb::sequence>(it->second.flatten_fn(obj));
if (nb::len(seq) == 2) {
nb::object aux = nb::cast<nb::object>(seq[1]);
if (!aux.is_none()) {
try {
uint64_t aux_hash = static_cast<uint64_t>(nb::hash(aux));
fp ^= aux_hash + 0x9e3779b97f4a7c15ULL + (fp << 6) + (fp >> 2);
} catch (...) {
// Unhashable aux — fall back to type-only fingerprint.
}
}
}
} catch (...) {
// flatten_fn failed — fall back to type-only fingerprint.
}
}
return fp;
}

void init_trees(nb::module_& m) {
m.def(
"register_pytree_node",
&register_pytree_node,
nb::arg("cls"),
nb::arg("flatten_fn"),
nb::arg("unflatten_fn"),
R"pbdoc(
Register a custom class as a pytree node.

Once registered, instances of ``cls`` are treated as interior nodes
(not leaves) by :func:`mlx.core.compile`, :func:`mlx.utils.tree_map`,
:func:`mlx.utils.tree_flatten`, and friends.

Args:
cls (type): The class to register.
flatten_fn (callable): ``flatten_fn(obj) -> (children, aux_data)``
where *children* is a list or tuple of sub-trees (may contain
:class:`array` or further nested structures) and *aux_data* is
any hashable metadata needed to reconstruct the object. When
a registered object appears as a :func:`compile` argument the
hash of *aux_data* participates in the compile cache key, so
two instances with different aux trigger a retrace.
unflatten_fn (callable): ``unflatten_fn(aux_data, children) -> obj``
that recreates the original object from *aux_data* and the
(possibly updated) *children* list.

Example:

>>> import mlx.core as mx
>>>
>>> class Pair:
... def __init__(self, a, b):
... self.a = a
... self.b = b
...
>>> mx.register_pytree_node(
... Pair,
... lambda p: ([p.a, p.b], None),
... lambda _, children: Pair(*children),
... )
>>>
>>> @mx.compile
... def add_pair(p):
... return p.a + p.b
...
>>> add_pair(Pair(mx.array(1), mx.array(2)))
array(3, dtype=int32)
)pbdoc");
}

template <typename T, typename U, typename V>
void validate_subtrees(const std::vector<nb::object>& subtrees) {
int len = nb::cast<T>(subtrees[0]).size();
Expand Down Expand Up @@ -76,6 +228,39 @@ nb::object tree_map(
d[item.first] = recurse(items);
}
return nb::cast<nb::object>(d);
} else if (is_registered_pytree(subtrees[0])) {
auto [children, aux] = flatten_registered(subtrees[0]);
PyTypeObject* type = Py_TYPE(subtrees[0].ptr());

// Pre-flatten every other subtree so we can index parallel children.
std::vector<std::vector<nb::object>> other_children(subtrees.size());
other_children[0] = std::move(children);
for (size_t j = 1; j < subtrees.size(); ++j) {
if (is_registered_pytree(subtrees[j]) &&
Py_TYPE(subtrees[j].ptr()) == type) {
other_children[j] = flatten_registered(subtrees[j]).first;
if (other_children[j].size() != other_children[0].size()) {
throw std::invalid_argument(
"[tree_map] Additional input tree is not a valid prefix of "
"the first tree.");
}
}
}

std::vector<nb::object> new_children;
new_children.reserve(other_children[0].size());
for (size_t i = 0; i < other_children[0].size(); ++i) {
std::vector<nb::object> items(subtrees.size());
for (size_t j = 0; j < subtrees.size(); ++j) {
if (!other_children[j].empty()) {
items[j] = other_children[j][i];
} else {
items[j] = subtrees[j];
}
}
new_children.push_back(recurse(items));
}
return unflatten_registered(subtrees[0], aux, new_children);
} else {
return transform(subtrees);
}
Expand Down Expand Up @@ -143,6 +328,32 @@ void tree_visit(
}
recurse(items);
}
} else if (is_registered_pytree(subtrees[0])) {
PyTypeObject* type = Py_TYPE(subtrees[0].ptr());
std::vector<std::vector<nb::object>> other_children(subtrees.size());
other_children[0] = flatten_registered(subtrees[0]).first;
for (size_t j = 1; j < subtrees.size(); ++j) {
if (is_registered_pytree(subtrees[j]) &&
Py_TYPE(subtrees[j].ptr()) == type) {
other_children[j] = flatten_registered(subtrees[j]).first;
if (other_children[j].size() != other_children[0].size()) {
throw std::invalid_argument(
"[tree_visit] Additional input tree is not a valid prefix of "
"the first tree.");
}
}
}
for (size_t i = 0; i < other_children[0].size(); ++i) {
std::vector<nb::object> items(subtrees.size());
for (size_t j = 0; j < subtrees.size(); ++j) {
if (!other_children[j].empty()) {
items[j] = other_children[j][i];
} else {
items[j] = subtrees[j];
}
}
recurse(items);
}
} else {
visitor(subtrees);
}
Expand All @@ -162,6 +373,11 @@ void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
for (auto item : nb::cast<nb::dict>(subtree)) {
recurse(item.second);
}
} else if (is_registered_pytree(subtree)) {
auto [children, _] = flatten_registered(subtree);
for (const auto& child : children) {
recurse(child);
}
} else {
visitor(subtree);
}
Expand Down Expand Up @@ -197,6 +413,14 @@ void tree_visit_update(
d[item.first] = recurse(item.second);
}
return nb::cast<nb::object>(d);
} else if (is_registered_pytree(subtree)) {
auto [children, aux] = flatten_registered(subtree);
std::vector<nb::object> new_children;
new_children.reserve(children.size());
for (auto& c : children) {
new_children.push_back(recurse(c));
}
return unflatten_registered(subtree, aux, new_children);
} else if (nb::isinstance<mx::array>(subtree)) {
return visitor(subtree);
} else {
Expand Down
28 changes: 28 additions & 0 deletions python/src/trees.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <nanobind/nanobind.h>
#include <vector>

#include "mlx/array.h"

namespace mx = mlx::core;
namespace nb = nanobind;

// --------------------------------------------------------------------------
// Pytree node registry
//
// Allows third-party Python classes to participate in MLX tree utilities
// and in mx.compile argument flattening. Mirrors the API of
// jax.tree_util.register_pytree_node:
//
// flatten_fn(obj) -> (children: Sequence, aux_data: Any)
// unflatten_fn(aux, children) -> obj
// --------------------------------------------------------------------------

void register_pytree_node(
nb::type_object cls,
nb::callable flatten_fn,
nb::callable unflatten_fn);

// True if type(obj) has been registered as a pytree node.
bool is_registered_pytree(nb::handle obj);

// Returns the children list for a registered pytree node. Caller must ensure
// is_registered_pytree(obj) is true.
std::vector<nb::object> pytree_children(nb::handle obj);

// Compile cache fingerprint for a registered pytree's type+aux pair.
// Combines id(type) and hash(aux) so that compile retraces if either changes.
uint64_t registered_pytree_fingerprint(nb::handle obj);

void tree_visit(
const std::vector<nb::object>& trees,
std::function<void(const std::vector<nb::object>&)> visitor);
Expand Down
Loading