From 7e60107f5ee31addd1d0dfdd9113625a9919a200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ey=C3=BCp=20Can=20Akman?= Date: Sun, 10 May 2026 12:22:06 +0300 Subject: [PATCH] Add `copy` keyword to `mx.asarray` --- python/src/ops.cpp | 18 +++++++++++++++--- python/tests/test_array.py | 11 +++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 9a48b37afe..f0d61f6449 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -1765,23 +1765,35 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "asarray", - [](const nb::object& a, std::optional dtype) { + [](const nb::object& a, + std::optional dtype, + std::optional copy) { + if (copy.has_value() && !*copy) { + throw std::invalid_argument("[asarray] copy=False is not supported."); + } return create_array(a, dtype); }, nb::arg(), "dtype"_a = nb::none(), + nb::kw_only(), + "copy"_a = nb::none(), nb::sig( - "def asarray(a: Union[scalar, array, Sequence], dtype: " - "Optional[Dtype] = None) -> array"), + "def asarray(a: Union[scalar, array, Sequence], dtype: Optional[Dtype] = None, *, copy: Optional[bool] = None) -> array"), R"pbdoc( Convert the input to an array. Args: a: Input data. dtype (Dtype, optional): The desired data-type for the array. + copy (bool, optional): Must be ``True`` or unspecified. ``False`` + is not supported, since MLX has no in-place operations and + cannot return a non-copying view. Returns: array: An array interpretation of the input. + + Raises: + ValueError: If ``copy`` is ``False``. )pbdoc"); m.def( "zeros_like", diff --git a/python/tests/test_array.py b/python/tests/test_array.py index bc50b3d768..f2c2000234 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -2152,6 +2152,17 @@ def test_array_namespace_asarray(self): arr_pass = xp.asarray(existing) self.assertEqual(arr_pass.tolist(), [4, 5, 6]) + def test_asarray_copy(self): + existing = mx.array([1, 2, 3]) + + self.assertEqual(mx.asarray(existing, copy=True).tolist(), [1, 2, 3]) + self.assertEqual( + mx.asarray(existing, dtype=mx.float32, copy=True).dtype, mx.float32 + ) + + with self.assertRaises(ValueError): + mx.asarray(existing, copy=False) + def test_asarray(self): # List inputs self.assertEqual(mx.asarray([1, 2, 3]).tolist(), [1, 2, 3])