From 76f335060fe0219a3f7bfa31295a4a37b41e06b3 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 10 May 2026 16:01:38 +0100 Subject: [PATCH 1/4] Support Metal DLPack zero-copy import --- docs/src/usage/numpy.rst | 49 +++++- mlx/allocator.h | 4 + mlx/backend/cuda/allocator.cpp | 4 + mlx/backend/metal/allocator.cpp | 18 ++- mlx/backend/no_gpu/allocator.cpp | 4 + mlx/ops.cpp | 10 ++ mlx/ops.h | 3 + mlx/utils.cpp | 16 +- python/src/buffer.h | 8 + python/src/convert.cpp | 257 +++++++++++++++++++++---------- python/src/convert.h | 2 + python/src/mlx_func.cpp | 2 + python/tests/test_array.py | 85 +++++++++- 13 files changed, 369 insertions(+), 93 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index bf71938dff..e455fbebda 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,6 +76,7 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. +MLX can also import PyTorch tensors through DLPack with ``mx.array``. .. code-block:: python @@ -84,7 +85,53 @@ PyTorch supports DLPack inputs and can import MLX arrays directly. a = mx.arange(3) b = torch.tensor(a) - c = mx.array(b.cpu()) + c = mx.array(b) + +Creating an MLX array from a CPU tensor copies the data into MLX-owned storage. +The arrays do not share memory: + +.. code-block:: python + + b = torch.arange(3) + c = mx.array(b) + + b += 10 + print(c.tolist()) # [0, 1, 2] + +Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to +``mx.array``, MLX imports the underlying Metal buffer without copying it. The +PyTorch tensor and the MLX array then share the same storage. + +Since the buffer is shared across frameworks, synchronization has to be managed +explicitly. After PyTorch writes to an MPS tensor, call +``torch.mps.synchronize()`` before reading the shared data from MLX. After MLX +writes to the shared array, call ``mx.eval`` on the MLX result before reading +the shared data from PyTorch. Without these synchronization points, the other +framework may read the shared buffer before the producer has finished writing, +so it can observe stale data. + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.array(b) # zero-copy Metal DLPack import + + b.add_(10) + torch.mps.synchronize() + print(c.tolist()) # [10.0, 11.0, 12.0] + +Updates made by MLX can also be observed from PyTorch after the MLX computation +has been evaluated: + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.array(b) + + c += 10 + mx.eval(c) + print(b.cpu()) # tensor([10., 11., 12.]) JAX --- diff --git a/mlx/allocator.h b/mlx/allocator.h index 824deac2c7..185ccd6d76 100644 --- a/mlx/allocator.h +++ b/mlx/allocator.h @@ -21,6 +21,10 @@ class MLX_API Buffer { // Get the raw data pointer from the buffer void* raw_ptr(); + // Whether raw_ptr() can return a host-accessible pointer without moving or + // copying the buffer. + bool is_host_accessible() const; + // Get the buffer pointer from the buffer const void* ptr() const { return ptr_; diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp index 718ae33e9c..37dc457a02 100644 --- a/mlx/backend/cuda/allocator.cpp +++ b/mlx/backend/cuda/allocator.cpp @@ -416,6 +416,10 @@ void* Buffer::raw_ptr() { return cbuf.data; } +bool Buffer::is_host_accessible() const { + return true; +} + } // namespace allocator size_t get_active_memory() { diff --git a/mlx/backend/metal/allocator.cpp b/mlx/backend/metal/allocator.cpp index 222c6fd9fa..5b24142fd5 100644 --- a/mlx/backend/metal/allocator.cpp +++ b/mlx/backend/metal/allocator.cpp @@ -24,7 +24,23 @@ void* Buffer::raw_ptr() { if (!ptr_) { return nullptr; } - return static_cast(ptr_)->contents(); + auto* buf = static_cast(ptr_); + auto* contents = buf->contents(); + if (!contents && buf->length() > 0) { + throw std::runtime_error( + "[metal::Buffer::raw_ptr] Cannot access Metal buffer contents on the " + "host. The buffer is not CPU-addressable, for example because it uses " + "private storage."); + } + return contents; +} + +bool Buffer::is_host_accessible() const { + if (!ptr_) { + return true; + } + auto* buf = static_cast(ptr_); + return buf->storageMode() != MTL::StorageModePrivate; } } // namespace allocator diff --git a/mlx/backend/no_gpu/allocator.cpp b/mlx/backend/no_gpu/allocator.cpp index abb83e50e4..88b5e14441 100644 --- a/mlx/backend/no_gpu/allocator.cpp +++ b/mlx/backend/no_gpu/allocator.cpp @@ -76,6 +76,10 @@ void* Buffer::raw_ptr() { return static_cast(ptr_) + 1; } +bool Buffer::is_host_accessible() const { + return true; +} + Buffer CommonAllocator::malloc(size_t size) { void* ptr = std::malloc(size + sizeof(size_t)); if (ptr != nullptr) { diff --git a/mlx/ops.cpp b/mlx/ops.cpp index defcc2f6e0..f765b67aa9 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -294,6 +294,16 @@ array copy(array a, StreamOrDevice s /* = {} */) { {std::move(a)}); } +array copy_to_new_buffer(array a, StreamOrDevice s /* = {} */) { + auto copied_shape = a.shape(); // |a| will be moved + auto dtype = a.dtype(); + return array( + std::move(copied_shape), + dtype, + std::make_shared(to_stream(s), dtype), + {std::move(a)}); +} + array full_impl(array vals, Dtype dtype, StreamOrDevice s /* = {} */) { return array( vals.shape(), diff --git a/mlx/ops.h b/mlx/ops.h index 208964d1aa..3084de3962 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -60,6 +60,9 @@ MLX_API array as_strided( /** Copy another array. */ MLX_API array copy(array a, StreamOrDevice s = {}); +/** Copy another array into newly allocated storage. */ +MLX_API array copy_to_new_buffer(array a, StreamOrDevice s = {}); + /** Fill an array of the given shape with the given value(s). */ MLX_API array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); MLX_API array full(Shape shape, array vals, StreamOrDevice s = {}); diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 239e6603dd..39056ee79e 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -7,6 +7,7 @@ #include #include "mlx/dtype_utils.h" +#include "mlx/ops.h" #include "mlx/types/limits.h" #include "mlx/utils.h" @@ -212,6 +213,19 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { namespace { +array host_accessible_array(array a) { + a.eval(); + a.wait(); + if (a.buffer().is_host_accessible()) { + return a; + } + auto out = copy_to_new_buffer(std::move(a), Device::gpu); + out.eval(); + out.wait(); + out.detach(); + return out; +} + template void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { int num_print = 3; @@ -277,7 +291,7 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { } std::ostream& operator<<(std::ostream& os, array a) { - a.eval(); + a = host_accessible_array(std::move(a)); dispatch_all_types(a.dtype(), [&](auto type_tag) { print_array(os, a); }); diff --git a/python/src/buffer.h b/python/src/buffer.h index 272a918883..8d01b82132 100644 --- a/python/src/buffer.h +++ b/python/src/buffer.h @@ -91,6 +91,14 @@ extern "C" inline int getbuffer(PyObject* obj, Py_buffer* view, int flags) { { nb::gil_scoped_release nogil; a.eval(); + a.wait(); + } + if (!a.buffer().is_host_accessible()) { + PyErr_SetString( + PyExc_BufferError, + "Cannot provide a buffer for an array whose storage is not " + "CPU-addressable."); + return -1; } std::vector shape(a.shape().begin(), a.shape().end()); diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a5455c2b33..0b78a8fe20 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -1,13 +1,16 @@ // Copyright © 2024 Apple Inc. #include +#include #include #include +#include #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/ops.h" #include "mlx/utils.h" enum PyScalarT { @@ -31,6 +34,16 @@ int check_shape_dim(int64_t dim) { return static_cast(dim); } +template +mx::Shape get_shape(const nb::ndarray& nd_array) { + mx::Shape shape; + shape.reserve(nd_array.ndim()); + for (int i = 0; i < nd_array.ndim(); i++) { + shape.push_back(check_shape_dim(nd_array.shape(i))); + } + return shape; +} + template mx::array nd_array_to_mlx_contiguous( nb::ndarray nd_array, @@ -42,88 +55,149 @@ mx::array nd_array_to_mlx_contiguous( return mx::array(static_cast(data_ptr), shape, dtype); } -mx::array nd_array_to_mlx( - nb::ndarray nd_array, - std::optional dtype, - std::optional nb_dtype) { - if (nd_array.device_type() != nb::device::cpu::value) { - throw std::invalid_argument( - "Cannot convert non-CPU DLPack array to mlx array."); - } - - // Compute the shape and size - mx::Shape shape; - shape.reserve(nd_array.ndim()); - for (int i = 0; i < nd_array.ndim(); i++) { - shape.push_back(check_shape_dim(nd_array.shape(i))); - } - auto type = nb_dtype.value_or(nd_array.dtype()); - - // Copy data and make array +template +auto dispatch_dlpack_dtype( + nb::dlpack::dtype type, + F&& f, + const char* error_message) { if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bool_)); + return f.template operator()(mx::bool_); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint8)); + return f.template operator()(mx::uint8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint16)); + return f.template operator()(mx::uint16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint32)); + return f.template operator()(mx::uint32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::uint64)); + return f.template operator()(mx::uint64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int8)); + return f.template operator()(mx::int8); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int16)); + return f.template operator()(mx::int16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int32)); + return f.template operator()(mx::int32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::int64)); + return f.template operator()(mx::int64); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float16)); + return f.template operator()(mx::float16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::bfloat16)); + return f.template operator()(mx::bfloat16); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::float32)); + return f.template operator()(mx::float32); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else if (type == nb::dtype>()) { - return nd_array_to_mlx_contiguous( - nd_array, shape, dtype.value_or(mx::complex64)); + return f.template operator()(mx::complex64); } else { - throw std::invalid_argument("Cannot convert numpy array to mlx array."); + throw std::invalid_argument(error_message); } } +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + std::optional dtype); + +mx::array host_accessible_array(mx::array a) { + a.eval(); + a.wait(); + if (a.buffer().is_host_accessible()) { + return a; + } + auto out = mx::copy_to_new_buffer(std::move(a), mx::Device::gpu); + out.eval(); + out.wait(); + out.detach(); + return out; +} + +mx::array nd_array_to_mlx( + nb::ndarray nd_array, + std::optional dtype, + std::optional nb_dtype) { + switch (nd_array.device_type()) { + case nb::device::cpu::value: { + auto shape = get_shape(nd_array); + auto type = nb_dtype.value_or(nd_array.dtype()); + return dispatch_dlpack_dtype( + type, + [&](mx::Dtype default_dtype) { + return nd_array_to_mlx_contiguous( + nd_array, shape, dtype.value_or(default_dtype)); + }, + "Cannot convert numpy array to mlx array."); + } + case nb::device::metal::value: + return metal_dlpack_to_mlx(std::move(nd_array), dtype); + default: + throw std::invalid_argument("Unsupported DLPack device."); + } +} + +template +mx::array metal_dlpack_to_mlx_contiguous( + std::shared_ptr> owner, + const mx::Shape& shape, + mx::Dtype type, + std::optional dtype) { + auto itemsize = mx::size_of(type); + if (owner->itemsize() != itemsize) { + throw std::invalid_argument( + "Cannot convert Metal DLPack dtype to mlx dtype."); + } + + auto byte_offset = owner->data_offset(); + if (byte_offset % itemsize != 0) { + throw std::invalid_argument( + "Metal DLPack byte offset is not aligned to dtype size."); + } + + auto out = mx::array( + mx::allocator::Buffer(owner->data_handle()), + shape, + type, + [](mx::allocator::Buffer) {}); + auto flags = out.flags(); + out.set_data( + out.buffer(), + out.data_size(), + out.strides(), + flags, + [owner = std::move(owner)](mx::allocator::Buffer) {}); + + auto offset = static_cast(byte_offset / itemsize); + if (offset != 0) { + out.copy_shared_buffer(out, out.strides(), flags, out.data_size(), offset); + } + + if (dtype) { + auto result = (*dtype == out.dtype()) + ? mx::copy_to_new_buffer(out, mx::Device::gpu) + : mx::astype(out, *dtype, mx::Device::gpu); + result.eval(); + result.wait(); + result.detach(); + return result; + } + return out; +} + template nb::ndarray mlx_to_nd_array_impl( mx::array a, std::optional t = {}) { { nb::gil_scoped_release nogil; - a.eval(); + a = host_accessible_array(std::move(a)); } std::vector shape(a.shape().begin(), a.shape().end()); + auto owner = nb::cast(a); return nb::ndarray( a.data(), a.ndim(), shape.data(), - /* owner= */ nb::none(), + /* owner= */ owner, a.strides().data(), t.value_or(nb::dtype())); } @@ -177,44 +251,60 @@ nb::object to_scalar(mx::array& a) { throw std::invalid_argument( "[convert] Only length-1 arrays can be converted to Python scalars."); } + auto host = mx::array(a); { nb::gil_scoped_release nogil; - a.eval(); + host = host_accessible_array(std::move(host)); } - switch (a.dtype()) { + switch (host.dtype()) { case mx::bool_: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint8: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint16: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::uint64: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int8: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int16: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::int64: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::float16: - return nb::cast(static_cast(a.item())); + return nb::cast(static_cast(host.item())); case mx::float32: - return nb::cast(a.item()); + return nb::cast(host.item()); case mx::bfloat16: - return nb::cast(static_cast(a.item())); + return nb::cast(static_cast(host.item())); case mx::complex64: - return nb::cast(a.item>()); + return nb::cast(host.item>()); case mx::float64: - return nb::cast(a.item()); + return nb::cast(host.item()); default: throw nb::type_error("type cannot be converted to Python scalar."); } } +mx::array metal_dlpack_to_mlx( + nb::ndarray nd_array, + std::optional dtype) { + auto owner = + std::make_shared>(std::move(nd_array)); + auto shape = get_shape(*owner); + + return dispatch_dlpack_dtype( + owner->dtype(), + [&](mx::Dtype type) { + return metal_dlpack_to_mlx_contiguous(owner, shape, type, dtype); + }, + "Cannot convert Metal DLPack array to mlx array."); +} + template nb::list to_list(mx::array& a, size_t index, int dim) { nb::list pl; @@ -234,39 +324,40 @@ nb::object tolist(mx::array& a) { if (a.ndim() == 0) { return to_scalar(a); } + auto host = mx::array(a); { nb::gil_scoped_release nogil; - a.eval(); + host = host_accessible_array(std::move(host)); } - switch (a.dtype()) { + switch (host.dtype()) { case mx::bool_: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint8: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::uint64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int8: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::int64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float32: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::bfloat16: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::float64: - return to_list(a, 0, 0); + return to_list(host, 0, 0); case mx::complex64: - return to_list>(a, 0, 0); + return to_list>(host, 0, 0); default: throw nb::type_error("data type cannot be converted to Python list."); } diff --git a/python/src/convert.h b/python/src/convert.h index 9341dd3122..a8e56d64f1 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -69,6 +69,8 @@ mx::array nd_array_to_mlx( nb::ndarray mlx_to_np_array(const mx::array& a); nb::ndarray<> mlx_to_dlpack(const mx::array& a); +mx::array host_accessible_array(mx::array a); + nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); diff --git a/python/src/mlx_func.cpp b/python/src/mlx_func.cpp index 09aceabe9b..9955e134f9 100644 --- a/python/src/mlx_func.cpp +++ b/python/src/mlx_func.cpp @@ -2,6 +2,8 @@ #include "python/src/mlx_func.h" +#include + // A garbage collected function which wraps nb::cpp_function // See https://github.com/wjakob/nanobind/discussions/919 diff --git a/python/tests/test_array.py b/python/tests/test_array.py index ae775323b0..6c4a1ded21 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -735,7 +735,7 @@ def test_array_np_conversion(self): self.assertEqual(x.tolist(), cvals) def test_array_np_dtype_conversion(self): - dtypes_list = [ + to_mlx_dtypes_list = [ (mx.bool_, np.bool_), (mx.uint8, np.uint8), (mx.uint16, np.uint16), @@ -750,13 +750,14 @@ def test_array_np_dtype_conversion(self): (mx.complex64, np.complex64), ] - for mlx_dtype, np_dtype in dtypes_list: + for mlx_dtype, np_dtype in to_mlx_dtypes_list: a_npy = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype) a_mlx = mx.array(a_npy) self.assertEqual(a_mlx.dtype, mlx_dtype) self.assertTrue(np.allclose(a_mlx, a_npy)) + for mlx_dtype, np_dtype in to_mlx_dtypes_list: b_mlx = mx.random.uniform( low=0, high=10, @@ -2048,19 +2049,89 @@ def test_dlpack(self): self.assertTrue(mx.array_equal(y, x)) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_non_cpu_error(self): + def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) self.assertEqual(x.__dlpack_device__()[0], 8) - with self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - mx.array(x) + y = mx.array(x) + self.assertEqual(y.dtype, mx.float32) + torch.mps.synchronize() + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_host_access(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + y = mx.array(x) + + torch.mps.synchronize() + self.assertIn("array(", repr(y)) + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + with self.assertRaises(BufferError): + memoryview(y) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_zero_copy_reads_torch_updates(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + y = mx.array(x) + + x.add_(100) + torch.mps.synchronize() + self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_dtype_argument_copies(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() + y_copy = mx.array(x, dtype=mx.float32) + expected = x.cpu().numpy().tolist() + + x.add_(100) + torch.mps.synchronize() + self.assertEqual(y_copy.tolist(), expected) + + z = mx.array(x, dtype=mx.float16) + self.assertEqual(z.dtype, mx.float16) + self.assertEqual(z.tolist(), x.to(torch.float16).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_data_offset(self): + view = torch.arange(12, device="mps", dtype=torch.float32)[3:9] + view_mx = mx.array(view) + torch.mps.synchronize() + self.assertEqual((view_mx + 1).tolist(), (view + 1).cpu().numpy().tolist()) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_bfloat16(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + bf = x.to(torch.bfloat16) + bf_mx = mx.array(bf) + self.assertEqual(bf_mx.dtype, mx.bfloat16) + torch.mps.synchronize() + self.assertEqual( + bf_mx.astype(mx.float32).tolist(), + bf.to(torch.float32).cpu().numpy().tolist(), + ) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_array_operand(self): a = mx.array([1]) b = torch.tensor([2]) self.assertTrue(mx.array_equal(a + b, mx.array([3]))) - with self.assertRaisesRegex(ValueError, "non-CPU DLPack"): - a + b.to("mps") + b_mps = b.to("mps") + torch.mps.synchronize() + self.assertTrue(mx.array_equal(a + b_mps, mx.array([3]))) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.array(x) + + y += 3 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) def test_getitem_with_list(self): a = mx.array([1, 2, 3, 4, 5]) From da1c8de785b3824ee9eac478b5865fd87217e586 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 10 May 2026 16:02:38 +0100 Subject: [PATCH 2/4] Add from_dlpack copy controls --- docs/src/usage/numpy.rst | 22 ++++++++++++++++--- python/src/convert.cpp | 29 +++++++++++++++++++++++++ python/src/convert.h | 1 + python/src/ops.cpp | 23 ++++++++++++++++++++ python/tests/test_array.py | 44 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 3 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index e455fbebda..f5641b4be5 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -76,7 +76,8 @@ PyTorch ------- PyTorch supports DLPack inputs and can import MLX arrays directly. -MLX can also import PyTorch tensors through DLPack with ``mx.array``. +MLX can also import PyTorch tensors through DLPack with ``mx.array`` or +``mx.from_dlpack``. .. code-block:: python @@ -99,8 +100,9 @@ The arrays do not share memory: print(c.tolist()) # [0, 1, 2] Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to -``mx.array``, MLX imports the underlying Metal buffer without copying it. The -PyTorch tensor and the MLX array then share the same storage. +``mx.array`` or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``, MLX +imports the underlying Metal buffer without copying it. The PyTorch tensor and +the MLX array then share the same storage. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. After PyTorch writes to an MPS tensor, call @@ -133,6 +135,20 @@ has been evaluated: mx.eval(c) print(b.cpu()) # tensor([10., 11., 12.]) +Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying +``copy=True`` asks MLX to create a new array instead of sharing the Metal +buffer: + +.. code-block:: python + + b = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + c = mx.from_dlpack(b, copy=True) + + b.add_(10) + torch.mps.synchronize() + print(c.tolist()) # [0.0, 1.0, 2.0] + JAX --- JAX fully supports the buffer protocol. diff --git a/python/src/convert.cpp b/python/src/convert.cpp index 0b78a8fe20..a19f1bf745 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -135,6 +135,35 @@ mx::array nd_array_to_mlx( } } +mx::array from_dlpack(nb::object v, std::optional copy) { + using ContigArray = nb::ndarray; + auto nd = nb::cast(v); + + switch (nd.device_type()) { + case nb::device::cpu::value: + if (copy == false) { + throw std::invalid_argument( + "Cannot import a CPU DLPack array without a copy."); + } + return nd_array_to_mlx(std::move(nd), std::nullopt); + case nb::device::metal::value: { + std::optional dtype; + if (copy == true) { + dtype = dispatch_dlpack_dtype( + nd.dtype(), + [](mx::Dtype dtype) { return dtype; }, + "Cannot convert Metal DLPack array to mlx array."); + } + return nd_array_to_mlx(std::move(nd), dtype); + } + case nb::device::cuda::value: + case nb::device::cuda_managed::value: + throw std::invalid_argument("CUDA DLPack import is not supported."); + default: + throw std::invalid_argument("Unsupported DLPack device."); + } +} + template mx::array metal_dlpack_to_mlx_contiguous( std::shared_ptr> owner, diff --git a/python/src/convert.h b/python/src/convert.h index a8e56d64f1..3ac6cb9f74 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -76,6 +76,7 @@ nb::object to_scalar(mx::array& a); nb::object tolist(mx::array& a); mx::array create_array(nb::object v, std::optional t); +mx::array from_dlpack(nb::object v, std::optional copy); mx::array array_from_list(nb::list pl, std::optional dtype); mx::array array_from_list(nb::tuple pl, std::optional dtype); diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9a48b37afe..1356813d44 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1783,6 +1783,29 @@ void init_ops(nb::module_& m) { Returns: array: An array interpretation of the input. )pbdoc"); + m.def( + "from_dlpack", + [](const nb::object& x, std::optional copy) { + return from_dlpack(x, copy); + }, + nb::arg(), + nb::kw_only(), + "copy"_a = nb::none(), + nb::sig( + "def from_dlpack(x: DLPackCompatible, /, *, copy: Optional[bool] = None) -> array"), + R"pbdoc( + Create an array from an object that supports DLPack. + + Args: + x: Input object implementing ``__dlpack__`` and + ``__dlpack_device__``. + copy (bool, optional): Whether to copy the input. If ``True``, + always copy. If ``False``, never copy. If ``None``, share memory + when possible and copy otherwise. + + Returns: + array: An array containing the input data. + )pbdoc"); m.def( "zeros_like", &mx::zeros_like, diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 6c4a1ded21..c2acb88bcf 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2048,6 +2048,20 @@ def test_dlpack(self): y = np.from_dlpack(x) self.assertTrue(mx.array_equal(y, x)) + def test_from_dlpack_cpu(self): + x = np.arange(3, dtype=np.float32) + + y = mx.from_dlpack(x) + x += 10 + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + + y = mx.from_dlpack(x, copy=True) + x += 10 + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + + with self.assertRaises(ValueError): + mx.from_dlpack(x, copy=False) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_import(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) @@ -2133,6 +2147,36 @@ def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_none_shares(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_false_shares(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=False) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_true_copies(self): + x = torch.arange(3, device="mps", dtype=torch.float32) + torch.mps.synchronize() + y = mx.from_dlpack(x, copy=True) + + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [0.0, 1.0, 2.0]) + def test_getitem_with_list(self): a = mx.array([1, 2, 3, 4, 5]) idx = [0, 2, 4] From 3d03aca9d99663b554d9e76f1b10b79458834897 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 12 May 2026 01:11:23 +0800 Subject: [PATCH 3/4] Support Metal DLPack zero-copy sharing --- docs/src/usage/numpy.rst | 7 ++- mlx/utils.cpp | 4 +- mlx/utils.h | 2 + python/src/array.cpp | 24 ++++++++- python/src/convert.cpp | 105 +++++++++++++++++++++++++++++++------ python/src/convert.h | 7 ++- python/tests/test_array.py | 73 +++++++++++++++++++++++--- 7 files changed, 192 insertions(+), 30 deletions(-) diff --git a/docs/src/usage/numpy.rst b/docs/src/usage/numpy.rst index f5641b4be5..fc511b8098 100644 --- a/docs/src/usage/numpy.rst +++ b/docs/src/usage/numpy.rst @@ -102,7 +102,8 @@ The arrays do not share memory: Metal DLPack inputs are different. If a PyTorch MPS tensor is passed to ``mx.array`` or to ``mx.from_dlpack`` with ``copy=None`` or ``copy=False``, MLX imports the underlying Metal buffer without copying it. The PyTorch tensor and -the MLX array then share the same storage. +the MLX array then share the same storage. MLX arrays exported to PyTorch with +DLPack are also shared without a copy. Since the buffer is shared across frameworks, synchronization has to be managed explicitly. After PyTorch writes to an MPS tensor, call @@ -135,6 +136,10 @@ has been evaluated: mx.eval(c) print(b.cpu()) # tensor([10., 11., 12.]) +For MLX arrays exported to PyTorch, the share is tied to the exported buffer. +MLX updates after export may rebind the MLX array to a new buffer, while the +PyTorch tensor continues to reference the exported buffer. + Use ``mx.from_dlpack`` when you need to control the copy behavior. Specifying ``copy=True`` asks MLX to create a new array instead of sharing the Metal buffer: diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 39056ee79e..ce5b8f3387 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -211,8 +211,6 @@ std::ostream& operator<<(std::ostream& os, uint8_t x) { return os; } -namespace { - array host_accessible_array(array a) { a.eval(); a.wait(); @@ -226,6 +224,8 @@ array host_accessible_array(array a) { return out; } +namespace { + template void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { int num_print = 3; diff --git a/mlx/utils.h b/mlx/utils.h index 7835a97028..486d2638bc 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -65,6 +65,8 @@ MLX_API void set_printoptions(PrintOptions options); MLX_API PrintFormatter& get_global_formatter(); +MLX_API array host_accessible_array(array a); + /** Print the exception and then abort. */ MLX_API void abort_with_exception(const std::exception& error); diff --git a/python/src/array.cpp b/python/src/array.cpp index 28c12f622c..013267b28e 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -496,7 +496,29 @@ void init_array(nb::module_& m) { new (&arr) mx::array(nd_array_to_mlx(nd, std::nullopt)); } }) - .def("__dlpack__", [](const mx::array& a) { return mlx_to_dlpack(a); }) + .def( + "__dlpack__", + [](const mx::array& a, + nb::object, + nb::object, + nb::object dl_device, + nb::object) { + std::optional dl_device_type; + if (!dl_device.is_none()) { + auto device = nb::cast(dl_device); + if (nb::len(device) != 2) { + throw nb::type_error( + "dl_device must be None or a tuple[int, int]"); + } + dl_device_type = nb::cast(device[0]); + } + return mlx_to_dlpack(a, dl_device_type); + }, + nb::kw_only(), + "stream"_a = nb::none(), + "max_version"_a = nb::none(), + "dl_device"_a = nb::none(), + "copy"_a = nb::none()) .def( "__dlpack_device__", [](const mx::array& a) { diff --git a/python/src/convert.cpp b/python/src/convert.cpp index a19f1bf745..d746eab34c 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -10,6 +10,8 @@ #include "python/src/convert.h" #include "python/src/utils.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/metal/metal.h" #include "mlx/ops.h" #include "mlx/utils.h" @@ -99,19 +101,6 @@ mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype); -mx::array host_accessible_array(mx::array a) { - a.eval(); - a.wait(); - if (a.buffer().is_host_accessible()) { - return a; - } - auto out = mx::copy_to_new_buffer(std::move(a), mx::Device::gpu); - out.eval(); - out.wait(); - out.detach(); - return out; -} - mx::array nd_array_to_mlx( nb::ndarray nd_array, std::optional dtype, @@ -176,7 +165,7 @@ mx::array metal_dlpack_to_mlx_contiguous( "Cannot convert Metal DLPack dtype to mlx dtype."); } - auto byte_offset = owner->data_offset(); + auto byte_offset = owner->byte_offset(); if (byte_offset % itemsize != 0) { throw std::invalid_argument( "Metal DLPack byte offset is not aligned to dtype size."); @@ -271,8 +260,92 @@ nb::ndarray mlx_to_np_array(const mx::array& a) { return mlx_to_nd_array(a); } -nb::ndarray<> mlx_to_dlpack(const mx::array& a) { - return mlx_to_nd_array<>(a); +template +nb::ndarray<> mlx_to_dlpack_impl(mx::array a, int dl_device_type) { + void* data = nullptr; + uint64_t byte_offset = 0; + { + nb::gil_scoped_release nogil; + a.eval(); + a.wait(); + if (dl_device_type == nb::device::cpu::value) { + a = host_accessible_array(std::move(a)); + data = a.data(); + } else { + data = a.buffer().ptr(); + byte_offset = a.offset(); + } + } + + std::vector shape(a.shape().begin(), a.shape().end()); + auto owner = nb::cast(a); + return nb::ndarray<>( + data, + a.ndim(), + shape.data(), + /* owner= */ owner, + a.strides().data(), + nb::dtype(), + dl_device_type, + 0, + '\0', + byte_offset); +} + +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional dl_device_type) { + int device_type = dl_device_type.value_or( + mx::metal::is_available() + ? nb::device::metal::value + : (mx::cu::is_available() ? nb::device::cuda_managed::value + : nb::device::cpu::value)); + + if (device_type == nb::device::cuda::value || + device_type == nb::device::cuda_managed::value) { + throw nb::buffer_error("CUDA DLPack export is not supported."); + } + if (device_type != nb::device::cpu::value && + device_type != nb::device::metal::value) { + throw nb::buffer_error( + "Cannot export mlx array to requested DLPack device."); + } + if (device_type == nb::device::metal::value && !mx::metal::is_available()) { + throw nb::buffer_error("Metal DLPack export is not available."); + } + + switch (a.dtype()) { + case mx::bool_: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint8: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint16: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint32: + return mlx_to_dlpack_impl(a, device_type); + case mx::uint64: + return mlx_to_dlpack_impl(a, device_type); + case mx::int8: + return mlx_to_dlpack_impl(a, device_type); + case mx::int16: + return mlx_to_dlpack_impl(a, device_type); + case mx::int32: + return mlx_to_dlpack_impl(a, device_type); + case mx::int64: + return mlx_to_dlpack_impl(a, device_type); + case mx::float16: + return mlx_to_dlpack_impl(a, device_type); + case mx::bfloat16: + return mlx_to_dlpack_impl(a, device_type); + case mx::float32: + return mlx_to_dlpack_impl(a, device_type); + case mx::float64: + return mlx_to_dlpack_impl(a, device_type); + case mx::complex64: + return mlx_to_dlpack_impl>(a, device_type); + default: + throw nb::buffer_error("Cannot export mlx array with unsupported dtype."); + } } nb::object to_scalar(mx::array& a) { diff --git a/python/src/convert.h b/python/src/convert.h index 3ac6cb9f74..bf93540f2a 100644 --- a/python/src/convert.h +++ b/python/src/convert.h @@ -7,7 +7,6 @@ #include #include "mlx/array.h" -#include "mlx/ops.h" namespace mx = mlx::core; namespace nb = nanobind; @@ -67,9 +66,9 @@ mx::array nd_array_to_mlx( std::optional nb_dtype = std::nullopt); nb::ndarray mlx_to_np_array(const mx::array& a); -nb::ndarray<> mlx_to_dlpack(const mx::array& a); - -mx::array host_accessible_array(mx::array a); +nb::ndarray<> mlx_to_dlpack( + const mx::array& a, + std::optional dl_device_type = std::nullopt); nb::object to_scalar(mx::array& a); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index c2acb88bcf..7340e1dad5 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2035,17 +2035,28 @@ def test_add_numpy(self): self.assertEqual(z.item(), 3) def test_dlpack(self): + class CpuDLPack: + def __init__(self, array): + self.array = array + + def __dlpack_device__(self): + return (1, 0) + + def __dlpack__(self, *args, **kwargs): + kwargs["dl_device"] = (1, 0) + return self.array.__dlpack__(*args, **kwargs) + x = mx.array(1, dtype=mx.int32) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.array([[1.0, 2.0], [3.0, 4.0]]) - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) x = mx.arange(16).reshape(4, 4) x = x[::2, ::2] - y = np.from_dlpack(x) + y = np.from_dlpack(CpuDLPack(x)) self.assertTrue(mx.array_equal(y, x)) def test_from_dlpack_cpu(self): @@ -2084,14 +2095,19 @@ def test_torch_mps_dlpack_host_access(self): memoryview(y) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_zero_copy_reads_torch_updates(self): + def test_torch_mps_dlpack_zero_copy_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() y = mx.array(x) x.add_(100) torch.mps.synchronize() self.assertEqual((y + 1).tolist(), (x + 1).cpu().numpy().tolist()) + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_dtype_argument_copies(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) @@ -2148,7 +2164,44 @@ def test_torch_mps_dlpack_mlx_update_writes_torch_buffer(self): self.assertEqual(x.cpu().numpy().tolist(), [3, 4, 5]) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_from_dlpack_torch_mps_copy_none_shares(self): + def test_mlx_dlpack_exports_mps_tensor_to_torch(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.utils.dlpack.from_dlpack(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_exports_mps_tensor_to_torch_tensor(self): + x = mx.array([1]).astype(mx.float16) + mx.eval(x) + y = torch.tensor(x) + torch.mps.synchronize() + + self.assertEqual(y.device.type, "mps") + self.assertEqual(y.dtype, torch.float16) + self.assertEqual(y.cpu().numpy().tolist(), [1.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_mlx_dlpack_export_torch_update_writes_mlx_buffer(self): + x = mx.arange(8, dtype=mx.float32) + y = x[2:6] + mx.eval(y) + t = torch.utils.dlpack.from_dlpack(y) + + self.assertEqual(t.device.type, "mps") + self.assertEqual(t.cpu().numpy().tolist(), [2.0, 3.0, 4.0, 5.0]) + + t.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [12.0, 13.0, 14.0, 15.0]) + self.assertEqual(x.tolist(), [0.0, 1.0, 12.0, 13.0, 14.0, 15.0, 6.0, 7.0]) + + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_from_dlpack_torch_mps_copy_none_shares_updates(self): x = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() y = mx.from_dlpack(x) @@ -2157,8 +2210,12 @@ def test_from_dlpack_torch_mps_copy_none_shares(self): torch.mps.synchronize() self.assertEqual(y.tolist(), [10.0, 11.0, 12.0]) + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), [20.0, 21.0, 22.0]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_from_dlpack_torch_mps_copy_false_shares(self): + def test_from_dlpack_torch_mps_copy_false_shares_updates(self): x = torch.arange(3, device="mps", dtype=torch.float32) torch.mps.synchronize() y = mx.from_dlpack(x, copy=False) @@ -2167,6 +2224,10 @@ def test_from_dlpack_torch_mps_copy_false_shares(self): mx.eval(y) self.assertEqual(x.cpu().numpy().tolist(), [10.0, 11.0, 12.0]) + x.add_(10) + torch.mps.synchronize() + self.assertEqual(y.tolist(), [20.0, 21.0, 22.0]) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_from_dlpack_torch_mps_copy_true_copies(self): x = torch.arange(3, device="mps", dtype=torch.float32) From 4e16f1d317bcf3c600924cae71dd600ae2fe443e Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Tue, 12 May 2026 01:19:33 +0800 Subject: [PATCH 4/4] Share DLPack arrays when dtype matches --- python/src/convert.cpp | 21 ++++++++++++++++----- python/tests/test_array.py | 23 ++++++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/python/src/convert.cpp b/python/src/convert.cpp index d746eab34c..0bb6045203 100644 --- a/python/src/convert.cpp +++ b/python/src/convert.cpp @@ -97,6 +97,13 @@ auto dispatch_dlpack_dtype( } } +mx::Dtype mlx_dtype_from_dlpack( + nb::dlpack::dtype type, + const char* error_message) { + return dispatch_dlpack_dtype( + type, [](mx::Dtype dtype) { return dtype; }, error_message); +} + mx::array metal_dlpack_to_mlx( nb::ndarray nd_array, std::optional dtype); @@ -138,10 +145,8 @@ mx::array from_dlpack(nb::object v, std::optional copy) { case nb::device::metal::value: { std::optional dtype; if (copy == true) { - dtype = dispatch_dlpack_dtype( - nd.dtype(), - [](mx::Dtype dtype) { return dtype; }, - "Cannot convert Metal DLPack array to mlx array."); + dtype = mlx_dtype_from_dlpack( + nd.dtype(), "Cannot convert Metal DLPack array to mlx array."); } return nd_array_to_mlx(std::move(nd), dtype); } @@ -704,7 +709,13 @@ mx::array create_array(nb::object v, std::optional t) { } else { nd = nb::cast(v); } - return nd_array_to_mlx(nd, t, nb_dtype); + auto type = nb_dtype.value_or(nd.dtype()); + std::optional copy_dtype; + if (t && + *t != mlx_dtype_from_dlpack(type, "Cannot convert array to mlx.")) { + copy_dtype = t; + } + return nd_array_to_mlx(nd, copy_dtype, nb_dtype); } else { auto arr = to_array_with_accessor(v); return mx::astype(arr, t.value_or(arr.dtype())); diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 7340e1dad5..f518cbaa16 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2109,19 +2109,32 @@ def test_torch_mps_dlpack_zero_copy_shares_updates(self): self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") - def test_torch_mps_dlpack_dtype_argument_copies(self): + def test_torch_mps_dlpack_matching_dtype_argument_shares_updates(self): x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) torch.mps.synchronize() - y_copy = mx.array(x, dtype=mx.float32) - expected = x.cpu().numpy().tolist() + y = mx.array(x, dtype=mx.float32) x.add_(100) torch.mps.synchronize() - self.assertEqual(y_copy.tolist(), expected) + self.assertEqual(y.tolist(), x.cpu().numpy().tolist()) + + y += 10 + mx.eval(y) + self.assertEqual(x.cpu().numpy().tolist(), y.tolist()) + @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") + def test_torch_mps_dlpack_different_dtype_argument_copies(self): + x = torch.arange(12, device="mps", dtype=torch.float32).reshape(3, 4) + torch.mps.synchronize() z = mx.array(x, dtype=mx.float16) + expected = x.to(torch.float16).cpu().numpy().tolist() + self.assertEqual(z.dtype, mx.float16) - self.assertEqual(z.tolist(), x.to(torch.float16).cpu().numpy().tolist()) + self.assertEqual(z.tolist(), expected) + + x.add_(100) + torch.mps.synchronize() + self.assertEqual(z.tolist(), expected) @unittest.skipUnless(has_torch_mps, "PyTorch MPS is required") def test_torch_mps_dlpack_data_offset(self):