```python
import jax.numpy as jnp
import tree_math as tm


def f(x, y):
    return x, y


x = y = tm.Vector(jnp.array(0.))

# Passing out_vectors as a tuple works:
tm.unwrap(f, out_vectors=(True, False))(x, y)
# (tree_math.Vector(DeviceArray(0., dtype=float32, weak_type=True)), DeviceArray(0., dtype=float32, weak_type=True))

# Passing the same flags as a list raises:
tm.unwrap(f, out_vectors=[True, False])(x, y)
# ValueError: Expected list, got (DeviceArray(0., dtype=float32, weak_type=True), DeviceArray(0., dtype=float32, weak_type=True)).
```