diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index b6e6a70c6cba..f6e537de7565 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -841,11 +841,12 @@ impl MutableBuffer { /// /// This claims the memory used by this buffer in the pool, allowing for /// accurate accounting of memory usage. Any prior reservation will be - /// released so this works well when the buffer is being shared among - /// multiple arrays. + /// dropped before creating a new one to avoid transient double-counting. #[cfg(feature = "pool")] pub fn claim(&self, pool: &dyn MemoryPool) { - *self.reservation.lock().unwrap() = Some(pool.reserve(self.capacity())); + let mut guard = self.reservation.lock().unwrap(); + guard.take(); + *guard = Some(pool.reserve(self.capacity())); } } diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index a80a347fa17a..5aac5ec0ebc3 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -108,9 +108,14 @@ impl Bytes { } /// Register this [`Bytes`] with the provided [`MemoryPool`], replacing any prior reservation. + /// + /// This drops any existing reservation before creating the new one to + /// avoid transient double-counting of memory in the pool. #[cfg(feature = "pool")] pub(crate) fn claim(&self, pool: &dyn MemoryPool) { - *self.reservation.lock().unwrap() = Some(pool.reserve(self.capacity())); + let mut guard = self.reservation.lock().unwrap(); + guard.take(); + *guard = Some(pool.reserve(self.capacity())); } /// Resize the memory reservation of this buffer @@ -242,6 +247,12 @@ impl From for Bytes { #[cfg(test)] mod tests { use super::*; + #[cfg(feature = "pool")] + use crate::pool::{MemoryPool, MemoryReservation}; + #[cfg(feature = "pool")] + use std::sync::Arc; + #[cfg(feature = "pool")] + use std::sync::atomic::{AtomicUsize, Ordering}; #[test] fn test_from_bytes() { @@ -258,12 +269,130 @@ mod tests { assert_eq!(a_bytes.as_slice(), message); } + /// A pool wrapper that records the maximum [`MemoryPool::used`] value + /// observed during any [`MemoryPool::reserve`] call, allowing tests to + /// detect transient double-counting caused by failing to drop an existing + /// reservation before creating a new one. + #[cfg(feature = "pool")] + #[derive(Debug)] + struct MaxTrackerPool { + shared: Arc, + max_used: Arc, + } + + #[cfg(feature = "pool")] + impl MaxTrackerPool { + fn new() -> Self { + Self { + shared: Arc::new(AtomicUsize::new(0)), + max_used: Arc::new(AtomicUsize::new(0)), + } + } + + fn used(&self) -> usize { + self.shared.load(Ordering::Relaxed) + } + + fn max_used(&self) -> usize { + self.max_used.load(Ordering::Relaxed) + } + } + + #[cfg(feature = "pool")] + impl MemoryPool for MaxTrackerPool { + fn reserve(&self, size: usize) -> Box { + self.shared.fetch_add(size, Ordering::Relaxed); + let current = self.shared.load(Ordering::Relaxed); + self.max_used.fetch_max(current, Ordering::Relaxed); + Box::new(MaxTracker { + size, + shared: Arc::clone(&self.shared), + }) + } + + fn available(&self) -> isize { + isize::MAX - self.used() as isize + } + + fn used(&self) -> usize { + self.shared.load(Ordering::Relaxed) + } + + fn capacity(&self) -> usize { + usize::MAX + } + } + + #[cfg(feature = "pool")] + #[derive(Debug)] + struct MaxTracker { + size: usize, + shared: Arc, + } + + #[cfg(feature = "pool")] + impl Drop for MaxTracker { + fn drop(&mut self) { + self.shared.fetch_sub(self.size, Ordering::Relaxed); + } + } + + #[cfg(feature = "pool")] + impl MemoryReservation for MaxTracker { + fn size(&self) -> usize { + self.size + } + + fn resize(&mut self, new: usize) { + match self.size < new { + true => self.shared.fetch_add(new - self.size, Ordering::Relaxed), + false => self.shared.fetch_sub(self.size - new, Ordering::Relaxed), + }; + self.size = new; + } + } + #[cfg(feature = "pool")] mod pool_tests { use super::*; use crate::pool::TrackingMemoryPool; + #[test] + fn test_claim_does_not_double_count() { + // Verifies that claiming an already-claimed buffer does not + // transiently double-count memory in the pool. The MaxTrackerPool + // records the maximum used() value seen during any reserve() call. + let buffer = unsafe { + let layout = + std::alloc::Layout::from_size_align(1024, crate::alloc::ALIGNMENT).unwrap(); + let ptr = std::alloc::alloc(layout); + assert!(!ptr.is_null()); + Bytes::new( + NonNull::new(ptr).unwrap(), + 1024, + Deallocation::Standard(layout), + ) + }; + + let pool = MaxTrackerPool::new(); + assert_eq!(pool.used(), 0); + + // First claim + buffer.claim(&pool); + assert_eq!(pool.used(), 1024); + assert_eq!(pool.max_used(), 1024); + + // Second claim — without the fix this peaks at 2048 because + // reserve() is called while the old reservation is still live. + buffer.claim(&pool); + assert_eq!(pool.used(), 1024); + assert_eq!(pool.max_used(), 1024); + + drop(buffer); + assert_eq!(pool.used(), 0); + } + #[test] fn test_bytes_with_pool() { // Create a standard allocation