From e5aa760e1c630324de88615de4ed840724478a2c Mon Sep 17 00:00:00 2001 From: Akhil Mithran <97193607+Akhil-CM@users.noreply.github.com> Date: Wed, 4 Dec 2024 02:00:25 +0100 Subject: [PATCH] [data/air] handle numpy > 2.0.0 behaviour in _create_possibly_ragged_ndarray (#48064) ## Why are these changes needed? Fixes https://github.com/ray-project/ray/issues/47711 Also, takes into consideration new `copy` behaviour as mentioned in https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword ## Related issue number ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [x] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Akhil Mithran <97193607+Akhil-CM@users.noreply.github.com> Signed-off-by: Richard Liaw Co-authored-by: Richard Liaw --- python/ray/air/util/tensor_extensions/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/ray/air/util/tensor_extensions/utils.py b/python/ray/air/util/tensor_extensions/utils.py index 410f4c3a2a44..aa040b11b5e5 100644 --- a/python/ray/air/util/tensor_extensions/utils.py +++ b/python/ray/air/util/tensor_extensions/utils.py @@ -46,8 +46,16 @@ def _create_possibly_ragged_ndarray( # `np.array(...)` without the `dtype=object` parameter will raise a # VisibleDeprecationWarning which we suppress. # More details: https://stackoverflow.com/q/63097829 - warnings.simplefilter("ignore", category=np.VisibleDeprecationWarning) - return np.array(values, copy=False) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0": + copy_if_needed = None + warning_type = np.exceptions.VisibleDeprecationWarning + else: + copy_if_needed = False + warning_type = np.VisibleDeprecationWarning + + warnings.simplefilter("ignore", category=warning_type) + arr = np.array(values, copy=copy_if_needed) + return arr except ValueError as e: # Constructing a ragged ndarray directly via `np.array(...)` # without the `dtype=object` parameter will raise a ValueError.