diff --git a/python/ray/air/util/tensor_extensions/utils.py b/python/ray/air/util/tensor_extensions/utils.py index 410f4c3a2a44..aa040b11b5e5 100644 --- a/python/ray/air/util/tensor_extensions/utils.py +++ b/python/ray/air/util/tensor_extensions/utils.py @@ -46,8 +46,16 @@ def _create_possibly_ragged_ndarray( # `np.array(...)` without the `dtype=object` parameter will raise a # VisibleDeprecationWarning which we suppress. # More details: https://stackoverflow.com/q/63097829 - warnings.simplefilter("ignore", category=np.VisibleDeprecationWarning) - return np.array(values, copy=False) + if np.lib.NumpyVersion(np.__version__) >= "2.0.0": + copy_if_needed = None + warning_type = np.exceptions.VisibleDeprecationWarning + else: + copy_if_needed = False + warning_type = np.VisibleDeprecationWarning + + warnings.simplefilter("ignore", category=warning_type) + arr = np.array(values, copy=copy_if_needed) + return arr except ValueError as e: # Constructing a ragged ndarray directly via `np.array(...)` # without the `dtype=object` parameter will raise a ValueError.