diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c92f86b..8b79993 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,7 +106,7 @@ jobs: command: clippy args: -- -D warnings - name: Install dependencies for tools - run: sudo apt-get -y install libfontconfig1-dev jq + run: sudo apt-get update && sudo apt-get -y install libfontconfig1-dev jq - name: Check tools working-directory: tools run: cargo clippy -- -D warnings diff --git a/README.md b/README.md index aeef278..8c986e8 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Lightweight and high performance concurrent cache optimized for low cache overhe * Scales well with the number of threads * Atomic operations with `get_or_insert` and `get_value_or_guard` functions * Atomic async operations with `get_or_insert_async` and `get_value_or_guard_async` functions +* Closure-based `entry` API for atomic inspect-and-act patterns (keep, remove, replace) * Supports item pinning * Iteration and draining * Handles zero weight items efficiently @@ -50,7 +51,7 @@ struct StringWeighter; impl Weighter for StringWeighter { fn weight(&self, _key: &u64, val: &String) -> u64 { - // Be cautions out about zero weights! + // Be cautious about zero weights! 
val.len() as u64 } } @@ -64,6 +65,39 @@ fn main() { } ``` +Atomic inspect-and-act with the `entry` API + +```rust +use quick_cache::sync::{Cache, EntryAction, EntryResult}; + +fn main() { + let cache: Cache<u64, u64> = Cache::new(100); + + // Insert-or-get: if absent, compute and insert; if present, return cached + let result = cache.entry(&0, None, |_key, val| EntryAction::Retain(*val)); + let value = match result { + EntryResult::Retained(v) => v, + EntryResult::Vacant(guard) => { + let v = 42; // expensive computation + guard.insert(v).unwrap(); + v + } + _ => unreachable!(), + }; + assert_eq!(value, 42); + + // Conditionally remove: evict entries below a threshold + let result = cache.entry(&0, None, |_key, val| { + if *val < 100 { + EntryAction::<()>::Remove + } else { + EntryAction::Retain(()) + } + }); + assert!(matches!(result, EntryResult::Removed(0, 42))); +} +``` + Using the `Equivalent` trait for complex keys ```rust diff --git a/src/lib.rs b/src/lib.rs index 75afb3c..4ca92d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,8 +17,8 @@ //! # Equivalent keys //! //! The cache uses the [`Equivalent`](https://docs.rs/equivalent/1.0.1/equivalent/trait.Equivalent.html) trait -//! for gets/removals. It can helps work around the `Borrow` limitations. -//! For example, if the cache key is a tuple `(K, Q)`, you wouldn't be access to access such keys without +//! for gets/removals. It can help work around the `Borrow` limitations. +//! For example, if the cache key is a tuple `(K, Q)`, you wouldn't be able to access such keys without //! building a `&(K, Q)` and thus potentially cloning `K` and/or `Q`. //! //! # User defined weight //! @@ -31,6 +31,10 @@ //! are available, they can be mix and matched) the user can coordinate the insertion of entries, so only //! one value is "computed" and inserted after a cache miss. //! +//! The `entry` family of functions provides a closure-based API for atomically +//!
inspecting and acting on existing entries (keep, remove, or replace) while also coordinating +//! insertion on cache misses. +//! //! # Lifecycle hooks //! //! A user can optionally provide a custom [Lifecycle] implementation to hook into the lifecycle of cache entries. @@ -114,7 +118,7 @@ pub type DefaultHashBuilder = std::collections::hash_map::RandomState; /// /// impl Weighter for StringWeighter { /// fn weight(&self, _key: &u64, val: &String) -> u64 { -/// // Be cautious out about zero weights! +/// // Be cautious about zero weights! /// val.len() as u64 /// } /// } diff --git a/src/options.rs b/src/options.rs index 1edd5dc..ad527ae 100644 --- a/src/options.rs +++ b/src/options.rs @@ -122,7 +122,7 @@ impl OptionsBuilder { self } - /// Builds an `Option` struct which can be used in the `Cache::with_options` constructor. + /// Builds an `Options` struct which can be used in the `Cache::with_options` constructor. #[inline] pub fn build(&self) -> Result { let shards = self.shards.unwrap_or_else(|| available_parallelism() * 4); diff --git a/src/shard.rs b/src/shard.rs index acc3f4d..ec40b94 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -31,6 +31,42 @@ pub enum InsertStrategy { Replace { soft: bool }, } +/// What to do with an existing entry after inspection/mutation. +/// +/// Used with [`Cache::entry`](crate::sync::Cache::entry) and +/// [`Cache::entry_async`](crate::sync::Cache::entry_async). +pub enum EntryAction { + /// Retain the entry in the cache. The value may have been mutated in place + /// before returning this variant. + /// + /// Returns [`EntryResult::Retained(T)`](crate::sync::EntryResult::Retained). + Retain(T), + /// Remove the entry from the cache. + /// + /// Returns [`EntryResult::Removed(Key, Val)`](crate::sync::EntryResult::Removed). + Remove, + /// Remove the entry and get a [`PlaceholderGuard`](crate::sync::PlaceholderGuard) + /// for re-insertion. This is useful for "validate-or-recompute" patterns. 
+ /// + /// Returns [`EntryResult::Replaced(PlaceholderGuard, Val)`](crate::sync::EntryResult::Replaced). + ReplaceWithGuard, +} + +/// Result of an entry-or-placeholder operation at the shard level. +pub enum EntryOrPlaceholder { + /// Callback returned `Retain(T)` — entry is still in the cache. + Kept(T), + /// Callback returned `Remove` — entry was removed. + Removed(Key, Val), + /// Callback returned `ReplaceWithGuard` — entry was replaced with a placeholder. + /// The old value is returned so it can be dropped outside the lock. + Replaced(Plh, Val), + /// Found an existing placeholder (another loader is working on this key). + ExistingPlaceholder(Plh), + /// No entry existed, a new placeholder was created. + NewPlaceholder(Plh), +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] enum ResidentState { Hot, @@ -514,7 +550,7 @@ impl< .is_some() } - pub fn get(&self, hash: u64, key: &Q) -> Option<&Val> + pub fn get_key_value(&self, hash: u64, key: &Q) -> Option<(&Key, &Val)> where Q: Hash + Equivalent + ?Sized, { @@ -527,13 +563,21 @@ impl< resident.referenced.fetch_add(1, atomic::Ordering::Relaxed); } record_hit!(self); - Some(&resident.value) + Some((&resident.key, &resident.value)) } else { record_miss!(self); None } } + #[inline] + pub fn get(&self, hash: u64, key: &Q) -> Option<&Val> + where + Q: Hash + Equivalent + ?Sized, + { + self.get_key_value(hash, key).map(|(_k, v)| v) + } + pub fn get_mut(&mut self, hash: u64, key: &Q) -> Option> where Q: Hash + Equivalent + ?Sized, @@ -553,9 +597,12 @@ impl< let old_weight = self.weighter.weight(&resident.key, &resident.value); Some(RefMut { - idx, - old_weight, - cache: self, + guard: WeightGuard { + shard: self as *mut _, + idx, + old_weight, + }, + _phantom: std::marker::PhantomData, }) } @@ -573,9 +620,12 @@ impl< if let Some((Entry::Resident(resident), _)) = self.entries.get_mut(token) { let old_weight = self.weighter.weight(&resident.key, &resident.value); Some(RefMut { - old_weight, - idx: token, - cache: self, 
+ guard: WeightGuard { + shard: self as *mut _, + idx: token, + old_weight, + }, + _phantom: std::marker::PhantomData, }) } else { None @@ -1082,7 +1132,7 @@ impl< Ok(()) } - pub fn upsert_placeholder( + pub fn get_or_placeholder( &mut self, hash: u64, key: &Q, @@ -1090,28 +1140,143 @@ impl< where Q: Hash + Equivalent + ToOwned + ?Sized, { - let shared; - if let Some(idx) = self.search(hash, key) { - let (entry, _) = self.entries.get_mut(idx).unwrap(); - match entry { - Entry::Resident(resident) => { - if *resident.referenced.get_mut() < MAX_F { - *resident.referenced.get_mut() += 1; + let idx = self.search(hash, key); + if let Some(idx) = idx { + if let Some((Entry::Resident(resident), _)) = self.entries.get_mut(idx) { + if *resident.referenced.get_mut() < MAX_F { + *resident.referenced.get_mut() += 1; + } + record_hit_mut!(self); + unsafe { + // Rustc gets insanely confused returning references from mut borrows + // Safety: value will have the same lifetime as `resident` + let value_ptr: *const Val = &resident.value; + return Ok((idx, &*value_ptr)); + } + } + } + let (shared, is_new) = unsafe { self.non_resident_to_placeholder(hash, key, idx) }; + Err((shared, is_new)) + } + + /// Entry operation on an existing or missing key. + /// + /// If a `Resident` entry exists, calls `on_occupied` with `&mut Val` to decide what to do. + /// On `Retain`, weight is recalculated after the callback returns. + /// Otherwise, creates a placeholder or joins an existing one. + /// + /// `on_occupied` is taken by mutable reference so the caller retains ownership. + /// It is called at most once per invocation. 
+ pub fn entry_or_placeholder( + &mut self, + hash: u64, + key: &Q, + on_occupied: &mut F, + ) -> EntryOrPlaceholder + where + Q: Hash + Equivalent + ToOwned + ?Sized, + F: FnMut(&Key, &mut Val) -> EntryAction, + { + let idx = self.search(hash, key); + if let Some(idx) = idx { + let shard = self as *mut _; + if let Some((Entry::Resident(r), _)) = self.entries.get_mut(idx) { + // Call the callback inside a WeightGuard scope: if it panics + // after mutating the value, the guard recomputes weight on drop. + // SAFETY: key/value point into the Resident entry at idx, alive + // for the duration of the callback. + let action = { + let (key_ptr, val_ptr) = (&r.key as *const Key, &mut r.value as *mut Val); + let _guard = WeightGuard:: { + idx, + old_weight: self.weighter.weight(&r.key, &r.value), + shard, + }; + on_occupied(unsafe { &*key_ptr }, unsafe { &mut *val_ptr }) + }; + + return match action { + EntryAction::Retain(t) => { + record_hit_mut!(self); + let Some((Entry::Resident(resident), _)) = self.entries.get_mut(idx) else { + // SAFETY: we had a mut reference to the Resident under `idx` until the previous line + unsafe { unreachable_unchecked() }; + }; + if *resident.referenced.get_mut() < MAX_F { + *resident.referenced.get_mut() += 1; + } + EntryOrPlaceholder::Kept(t) } - record_hit_mut!(self); - unsafe { - // Rustc gets insanely confused returning references from mut borrows - // Safety: value will have the same lifetime as `resident` - let value_ptr: *const Val = &resident.value; - return Ok((idx, &*value_ptr)); + EntryAction::Remove => { + let (key, val) = self.remove_internal(hash, idx).unwrap(); + EntryOrPlaceholder::Removed(key, val) } - } + EntryAction::ReplaceWithGuard => { + let Some((Entry::Resident(r), _)) = self.entries.get_mut(idx) else { + // SAFETY: we had a mut reference to the Resident under `idx` until the previous line + unsafe { unreachable_unchecked() }; + }; + let state = r.state; + let current_weight = self.weighter.weight(&r.key, 
&r.value); + let list_head = if state == ResidentState::Hot { + self.num_hot -= 1; + self.weight_hot -= current_weight; + &mut self.hot_head + } else { + self.num_cold -= 1; + self.weight_cold -= current_weight; + &mut self.cold_head + }; + if current_weight != 0 { + let next = self.entries.unlink(idx); + if *list_head == Some(idx) { + *list_head = next; + } + } + let shared = Plh::new(hash, idx); + let (entry, _) = unsafe { self.entries.get_mut_unchecked(idx) }; + let Entry::Resident(r) = mem::replace(entry, Entry::Ghost(0)) else { + unsafe { unreachable_unchecked() } + }; + *entry = Entry::Placeholder(Placeholder { + key: r.key, + hot: state, + shared: shared.clone(), + }); + EntryOrPlaceholder::Replaced(shared, r.value) + } + }; + } + } + let (shared, is_new) = unsafe { self.non_resident_to_placeholder(hash, key, idx) }; + if is_new { + EntryOrPlaceholder::NewPlaceholder(shared) + } else { + EntryOrPlaceholder::ExistingPlaceholder(shared) + } + } + + /// Creates or joins a placeholder for a non-Resident entry (Placeholder, Ghost, or missing). + /// Returns `(shared, true)` for new placeholders, `(shared, false)` for existing ones. + /// The entry at `idx` must NOT be Resident. 
+ unsafe fn non_resident_to_placeholder( + &mut self, + hash: u64, + key: &Q, + idx: Option, + ) -> (Plh, bool) + where + Q: Hash + Equivalent + ToOwned + ?Sized, + { + if let Some(idx) = idx { + let (entry, _) = unsafe { self.entries.get_mut_unchecked(idx) }; + match entry { Entry::Placeholder(p) => { record_hit_mut!(self); - return Err((p.shared.clone(), false)); + (p.shared.clone(), false) } Entry::Ghost(_) => { - shared = Plh::new(hash, idx); + let shared = Plh::new(hash, idx); *entry = Entry::Placeholder(Placeholder { key: key.to_owned(), hot: ResidentState::Hot, @@ -1122,11 +1287,14 @@ impl< if self.ghost_head == Some(idx) { self.ghost_head = next; } + record_miss_mut!(self); + (shared, true) } + Entry::Resident(_) => unsafe { unreachable_unchecked() }, } } else { let idx = self.entries.next_free(); - shared = Plh::new(hash, idx); + let shared = Plh::new(hash, idx); let idx_ = self.entries.insert(Entry::Placeholder(Placeholder { key: key.to_owned(), hot: ResidentState::Cold, @@ -1134,9 +1302,9 @@ impl< })); debug_assert_eq!(idx, idx_); self.map_insert(hash, idx); + record_miss_mut!(self); + (shared, true) } - record_miss_mut!(self); - Err((shared, true)) } pub fn set_capacity(&mut self, new_weight_capacity: u64) { @@ -1170,55 +1338,66 @@ impl< } } -/// Structure wrapping a mutable reference to a cached item. -pub struct RefMut<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> { - cache: &'cache mut CacheShard, +/// Drop guard for `entry_or_placeholder`: if the user callback panics after +/// mutating the value, recomputes weight to keep shard accounting consistent. +struct WeightGuard, B, L, Plh: SharedPlaceholder> { + shard: *mut CacheShard, idx: Token, old_weight: u64, } +impl, B, L, Plh: SharedPlaceholder> Drop + for WeightGuard +{ + fn drop(&mut self) { + // SAFETY: shard pointer is valid — guard is created and dropped within + // entry_or_placeholder which holds &mut CacheShard. 
+ unsafe { + let shard = &mut *self.shard; + let (entry, _) = shard.entries.get_unchecked(self.idx); + let Entry::Resident(r) = entry else { + unreachable_unchecked() + }; + let new_weight = shard.weighter.weight(&r.key, &r.value); + if self.old_weight != new_weight { + shard.cold_change_weight(self.idx, self.old_weight, new_weight); + } + } + } +} + +/// Structure wrapping a mutable reference to a cached item. +/// On drop, recomputes weight via the inner [`WeightGuard`]. +pub struct RefMut<'cache, Key, Val, We: Weighter, B, L, Plh: SharedPlaceholder> { + guard: WeightGuard, + _phantom: std::marker::PhantomData<&'cache mut CacheShard>, +} + impl, B, L, Plh: SharedPlaceholder> RefMut<'_, Key, Val, We, B, L, Plh> { pub(crate) fn pair(&self) -> (&Key, &Val) { - // Safety: RefMut was constructed correctly from a Resident entry in get_mut or peek_token_mut - // and it couldn't be modified as we're holding a mutable reference to the cache + // SAFETY: RefMut is only constructed from a valid &mut CacheShard with a + // Resident entry at idx, and we hold exclusive access via the lifetime. unsafe { - if let (Entry::Resident(Resident { key, value, .. }), _) = - self.cache.entries.get_unchecked(self.idx) - { - (key, value) - } else { + let shard = &*self.guard.shard; + let (entry, _) = shard.entries.get_unchecked(self.guard.idx); + let Entry::Resident(Resident { key, value, .. }) = entry else { core::hint::unreachable_unchecked() - } + }; + (key, value) } } pub(crate) fn value_mut(&mut self) -> &mut Val { - // Safety: RefMut was constructed correctly from a Resident entry in get_mut or peek_token_mut - // and it couldn't be modified as we're holding a mutable reference to the cache + // SAFETY: same as pair(), plus we have &mut self so exclusive access is guaranteed. unsafe { - if let (Entry::Resident(Resident { value, .. 
}), _) = - self.cache.entries.get_mut_unchecked(self.idx) - { - value - } else { + let shard = &mut *self.guard.shard; + let (entry, _) = shard.entries.get_mut_unchecked(self.guard.idx); + let Entry::Resident(Resident { value, .. }) = entry else { core::hint::unreachable_unchecked() - } - } - } -} - -impl, B, L, Plh: SharedPlaceholder> Drop - for RefMut<'_, Key, Val, We, B, L, Plh> -{ - #[inline] - fn drop(&mut self) { - let (key, value) = self.pair(); - let new_weight = self.cache.weighter.weight(key, value); - if self.old_weight != new_weight { - self.cache - .cold_change_weight(self.idx, self.old_weight, new_weight); + }; + value } } } diff --git a/src/shuttle_tests.rs b/src/shuttle_tests.rs index 1d9960f..21ff933 100644 --- a/src/shuttle_tests.rs +++ b/src/shuttle_tests.rs @@ -6,7 +6,7 @@ use crate::{ sync::{self, atomic, Arc}, thread, }, - sync::GuardResult, + sync::{EntryAction, EntryResult, GuardResult}, }; use shuttle::{ @@ -216,3 +216,229 @@ fn test_waker_change_race_stub() { } }); } + +#[test] +fn test_entry_works() { + let mut config = shuttle::Config::default(); + config.max_steps = shuttle::MaxSteps::None; + let check_determinism = std::env::var("CHECK_DETERMINISM").is_ok_and(|s| !s.is_empty()); + if let Ok(seed) = std::env::var("SEED") { + let seed = std::fs::read_to_string(&seed).unwrap_or(seed.clone()); + let scheduler = shuttle::scheduler::ReplayScheduler::new_from_encoded(&seed); + let runner = shuttle::Runner::new(scheduler, config); + runner.run(test_entry_works_stub); + } else { + let max_iterations: usize = std::env::var("MAX_ITERATIONS") + .map(|s| s.parse().unwrap()) + .unwrap_or(1000); + let scheduler = shuttle::scheduler::RandomScheduler::new(max_iterations); + if check_determinism { + let scheduler = + shuttle::scheduler::UncontrolledNondeterminismCheckScheduler::new(scheduler); + let runner = shuttle::Runner::new(scheduler, config); + runner.run(test_entry_works_stub); + } else { + let runner = shuttle::Runner::new(scheduler, config); + 
runner.run(test_entry_works_stub); + } + } +} + +fn test_entry_works_stub() { + shuttle::future::block_on(test_entry_works_stub_async()) +} + +async fn test_entry_works_stub_async() { + const PAIRS: usize = 10; + let entered_: Arc = Arc::new(atomic::AtomicUsize::default()); + let cache_ = Arc::new(crate::sync::Cache::::new(100)); + let wg = Arc::new(tokio::sync::Barrier::new(PAIRS)); + let sync_wg = Arc::new(sync::Barrier::new(PAIRS)); + let solve_at = rand::thread_rng().gen_range(0..100); + let mut tasks = Vec::new(); + let mut threads = Vec::new(); + for _ in 0..PAIRS { + let cache = cache_.clone(); + let wg = wg.clone(); + let entered = entered_.clone(); + let task = spawn(async move { + wg.wait().await; + loop { + let yields = rand::thread_rng().gen_range(0..PAIRS * 2); + // a dummy timeout like future to race with the cache future in a select + let timeout_fut = std::pin::pin!(async { + for _ in 0..yields { + shuttle::future::yield_now().await; + } + }); + let action = rand::thread_rng().gen_range(0..3u8); + let cache_fut = std::pin::pin!(cache.entry_async(&0, move |_k, v| { + // Always keep the terminal value to ensure termination + if *v == 1 { + return EntryAction::Retain(*v); + } + match action { + 0 => EntryAction::Retain(*v), + 1 => EntryAction::Remove, + _ => EntryAction::ReplaceWithGuard, + } + })); + let cache_fut_res = tokio::select! 
{ + biased; + _ = timeout_fut => { + if rand::thread_rng().gen_bool(0.1) { + cache.insert(0, 0); + } + continue; + }, + result = cache_fut => result, + }; + match cache_fut_res { + EntryResult::Retained(v) => { + if v == 1 { + break; + } + shuttle::future::yield_now().await; + if rand::thread_rng().gen_bool(0.5) { + cache.remove(&0); + } + } + EntryResult::Removed(_, _) => { + shuttle::future::yield_now().await; + } + EntryResult::Vacant(g) | EntryResult::Replaced(g, _) => { + shuttle::future::yield_now().await; + let before = entered.fetch_add(1, atomic::Ordering::Relaxed); + if before >= solve_at { + let _ = g.insert(1); + } + } + EntryResult::Timeout => unreachable!(), + } + } + }); + tasks.push(task); + + let cache = cache_.clone(); + let wg = sync_wg.clone(); + let entered = entered_.clone(); + let thread = thread::spawn(move || { + wg.wait(); + loop { + // note that the actual duration is ignored during shuttle tests + let timeout = match rand::thread_rng().gen_range(0..3) { + 0 => None, + 1 => Some(std::time::Duration::default()), + _ => Some(std::time::Duration::from_millis(100)), + }; + let action = rand::thread_rng().gen_range(0..3u8); + match cache.entry(&0, timeout, |_k, v| { + if *v == 1 { + return EntryAction::Retain(*v); + } + match action { + 0 => EntryAction::Retain(*v), + 1 => EntryAction::Remove, + _ => EntryAction::ReplaceWithGuard, + } + }) { + EntryResult::Retained(v) => { + if v == 1 { + break; + } + shuttle::thread::yield_now(); + if rand::thread_rng().gen_bool(0.5) { + cache.remove(&0); + } + } + EntryResult::Removed(_, _) => { + shuttle::thread::yield_now(); + } + EntryResult::Vacant(g) | EntryResult::Replaced(g, _) => { + shuttle::thread::yield_now(); + let before = entered.fetch_add(1, atomic::Ordering::Relaxed); + if before >= solve_at { + let _ = g.insert(1); + } + } + EntryResult::Timeout => { + if rand::thread_rng().gen_bool(0.1) { + cache.insert(0, 0); + } + } + } + } + }); + threads.push(thread); + } + for task in tasks { + 
task.await.unwrap(); + } + for thread in threads { + thread.join().unwrap(); + } + assert_eq!(cache_.get(&0), Some(1)); +} + +#[test] +fn test_entry_waker_change_race() { + let mut config = shuttle::Config::default(); + config.max_steps = shuttle::MaxSteps::None; + if let Ok(seed) = std::env::var("SEED") { + let seed = std::fs::read_to_string(&seed).unwrap_or(seed.clone()); + let scheduler = shuttle::scheduler::ReplayScheduler::new_from_encoded(&seed); + let runner = shuttle::Runner::new(scheduler, config); + runner.run(test_entry_waker_change_race_stub); + } else { + let max_iterations: usize = std::env::var("MAX_ITERATIONS") + .map(|s| s.parse().unwrap()) + .unwrap_or(1000); + let scheduler = shuttle::scheduler::RandomScheduler::new(max_iterations); + let runner = shuttle::Runner::new(scheduler, config); + runner.run(test_entry_waker_change_race_stub); + } +} + +fn test_entry_waker_change_race_stub() { + let cache = Arc::new(crate::sync::Cache::::new(100)); + + // Acquire a placeholder guard via entry() on a vacant key. + let guard = match cache.entry(&0, None, |_k, _v| -> EntryAction<()> { unreachable!() }) { + EntryResult::Vacant(g) => g, + _ => unreachable!(), + }; + + // Create entry_async future — will find existing placeholder and wait. + // When the value arrives, entry_async loops back and the callback runs. + let mut fut = std::pin::pin!(cache.entry_async(&0, |_k, v| EntryAction::Retain(*v))); + + // First poll with waker W1 → Pending (registered in waiters list). + let w1 = noop_waker(1); + let mut cx1 = std::task::Context::from_waker(&w1); + assert!(fut.as_mut().poll(&mut cx1).is_pending()); + + // Scoped thread: insert value via guard while re-polling with different waker. + thread::scope(|s| { + s.spawn(|| { + let _ = guard.insert(42); + }); + + // Re-poll with a different waker W2 — exercises the will_wake() == false path. 
+ let w2 = noop_waker(2); + let mut cx2 = std::task::Context::from_waker(&w2); + loop { + match fut.as_mut().poll(&mut cx2) { + Poll::Ready(result) => { + match result { + EntryResult::Retained(v) => assert_eq!(v, 42), + _ => panic!("expected EntryResult::Retained"), + } + break; + } + Poll::Pending => { + shuttle::thread::yield_now(); + } + } + } + }); +} diff --git a/src/sync.rs b/src/sync.rs index 6a3366e..df77b40 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -14,7 +14,9 @@ use crate::{ DefaultHashBuilder, Equivalent, Lifecycle, MemoryUsed, UnitWeighter, Weighter, }; -pub use crate::sync_placeholder::{GuardResult, JoinFuture, PlaceholderGuard}; +use crate::shard::EntryOrPlaceholder; +pub use crate::sync_placeholder::{EntryAction, EntryResult, GuardResult, PlaceholderGuard}; +use crate::sync_placeholder::{JoinFuture, JoinResult}; /// A concurrent cache /// @@ -28,7 +30,7 @@ pub use crate::sync_placeholder::{GuardResult, JoinFuture, PlaceholderGuard}; /// `Arc>` or `Arc>` can also be used. /// /// # Thread Safety and Concurrency -/// The cache instance can wrapped with an `Arc` (or equivalent) and shared between threads. +/// The cache instance can be wrapped with an `Arc` (or equivalent) and shared between threads. /// All methods are accessible via non-mut references so no further synchronization (e.g. Mutex) is needed. pub struct Cache< Key, @@ -226,7 +228,7 @@ impl< self.shards.get(shard_idx).map(|s| (s, hash)) } - /// Reserver additional space for `additional` entries. + /// Reserve additional space for `additional` entries. /// Note that this is counted in entries, and is not weighted. pub fn reserve(&self, additional: usize) { let additional_per_shard = @@ -236,7 +238,7 @@ impl< } } - /// Check if a key exist in the cache. + /// Checks if a key exists in the cache. pub fn contains_key(&self, key: &Q) -> bool where Q: Hash + Equivalent + ?Sized, @@ -403,10 +405,10 @@ impl< /// Gets an item from the cache with key `key` . 
/// - /// If the corresponding value isn't present in the cache, this functions returns a guard + /// If the corresponding value isn't present in the cache, this function returns a guard /// that can be used to insert the value once it's computed. /// While the returned guard is alive, other calls with the same key using the - /// `get_value_guard` or `get_or_insert` family of functions will wait until the guard + /// `get_value_or_guard` or `get_or_insert` family of functions will wait until the guard /// is dropped or the value is inserted. /// /// A `None` `timeout` means waiting forever. @@ -451,10 +453,10 @@ impl< /// Gets an item from the cache with key `key`. /// - /// If the corresponding value isn't present in the cache, this functions returns a guard + /// If the corresponding value isn't present in the cache, this function returns a guard /// that can be used to insert the value once it's computed. /// While the returned guard is alive, other calls with the same key using the - /// `get_value_guard` or `get_or_insert` family of functions will wait until the guard + /// `get_value_or_guard` or `get_or_insert` family of functions will wait until the guard /// is dropped or the value is inserted. pub async fn get_value_or_guard_async<'a, Q>( &'a self, @@ -464,10 +466,20 @@ impl< Q: Hash + Equivalent + ToOwned + ?Sized, { let (shard, hash) = self.shard_for(key).unwrap(); - if let Some(v) = shard.read().get(hash, key) { - return Ok(v.clone()); + loop { + if let Some(v) = shard.read().get(hash, key) { + return Ok(v.clone()); + } + match JoinFuture::new(&self.lifecycle, shard, hash, key).await { + JoinResult::Filled(Some(shared)) => { + // SAFETY: Filled means the value was set by the loader. 
+ return Ok(unsafe { shared.value().unwrap_unchecked().clone() }); + } + JoinResult::Filled(None) => continue, + JoinResult::Guard(g) => return Err(g), + JoinResult::Timeout => unsafe { unreachable_unchecked() }, + } } - JoinFuture::new(&self.lifecycle, shard, hash, key).await } /// Gets or inserts an item in the cache with key `key`. @@ -489,6 +501,166 @@ } } + /// Atomically accesses an existing entry, or gets a guard for insertion. + /// + /// If a value exists for `key`, `on_occupied` is called with the key and a mutable + /// reference to the value. The callback returns an [`EntryAction`] to decide what to do: + /// - [`EntryAction::Retain`]`(T)` — keep the entry, return `T`. + /// Weight is recalculated after the callback returns. + /// - [`EntryAction::Remove`] — remove the entry from the cache. + /// - [`EntryAction::ReplaceWithGuard`] — remove the entry and get a guard for re-insertion. + /// + /// If no value exists, a [`PlaceholderGuard`] is returned for inserting a new value. + /// If another thread is already loading this key, waits up to `timeout` for the value + /// to arrive, then calls `on_occupied` on the result. + /// + /// A `None` `timeout` means waiting forever. + /// A `Some(duration)` timeout will return a Timeout if the duration elapses while a guard is alive elsewhere. + /// + /// The callback is `FnOnce` and runs **at most once**. + /// + /// # Performance + /// + /// Always acquires a **write lock** on the shard. For read-only lookups where + /// contention matters, prefer [`get`](Self::get), [`get_value_or_guard`](Self::get_value_or_guard) + /// or similar. + /// + /// The callback runs under the shard write lock — keep it short to avoid blocking + /// other operations on the same shard. **Do not** call back into the cache from the + /// callback, as this will deadlock when the same shard is accessed. + /// + /// # Panics + /// + /// If the callback panics, weight accounting is automatically corrected.
+ /// However, any partial mutation to the value will remain. + /// + /// # Examples + /// + /// ``` + /// use quick_cache::sync::{Cache, EntryAction, EntryResult}; + /// + /// let cache: Cache<String, u64> = Cache::new(5); + /// cache.insert("counter".to_string(), 0); + /// + /// // Mutate in place: increment a counter + /// let result = cache.entry("counter", None, |_k, v| { + /// *v += 1; + /// EntryAction::Retain(*v) + /// }); + /// assert!(matches!(result, EntryResult::Retained(1))); + /// assert_eq!(cache.get("counter"), Some(1)); + /// ``` + pub fn entry<Q, T>( + &self, + key: &Q, + timeout: Option<Duration>, + on_occupied: impl FnOnce(&Key, &mut Val) -> EntryAction<T>, + ) -> EntryResult<'_, Key, Val, We, B, L, T> + where + Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized, + { + let (shard, hash) = self.shard_for(key).unwrap(); + // Wrap FnOnce in Option so we can pass &mut FnMut to entry_or_placeholder + // in a loop. The loop retries only on ExistingPlaceholder (another thread is + // loading), which does not invoke the callback — so the Option is still Some + // on retry and the callback runs at most once.
+ let mut on_occupied = Some(on_occupied); + let mut callback = |k: &Key, v: &mut Val| on_occupied.take().unwrap()(k, v); + let mut deadline = timeout.map(Ok); + + loop { + let mut shard_guard = shard.write(); + match shard_guard.entry_or_placeholder(hash, key, &mut callback) { + EntryOrPlaceholder::Kept(t) => return EntryResult::Retained(t), + EntryOrPlaceholder::Removed(k, v) => return EntryResult::Removed(k, v), + EntryOrPlaceholder::Replaced(shared, old_val) => { + drop(shard_guard); + return EntryResult::Replaced( + PlaceholderGuard::start_loading(&self.lifecycle, shard, shared), + old_val, + ); + } + EntryOrPlaceholder::NewPlaceholder(shared) => { + drop(shard_guard); + return EntryResult::Vacant(PlaceholderGuard::start_loading( + &self.lifecycle, + shard, + shared, + )); + } + EntryOrPlaceholder::ExistingPlaceholder(shared) => { + match PlaceholderGuard::wait_for_placeholder( + &self.lifecycle, + shard, + shard_guard, + shared, + deadline.as_mut(), + ) { + JoinResult::Filled(_) => continue, + JoinResult::Guard(g) => return EntryResult::Vacant(g), + JoinResult::Timeout => return EntryResult::Timeout, + } + } + } + } + } + + /// Async version of [`Self::entry`]. + /// + /// Atomically accesses an existing entry, or gets a guard for insertion. + /// If another task is already loading this key, waits asynchronously for the value. + /// + /// See [`entry`](Self::entry) for full documentation. + pub async fn entry_async<'a, Q, T>( + &'a self, + key: &Q, + on_occupied: impl FnOnce(&Key, &mut Val) -> EntryAction, + ) -> EntryResult<'a, Key, Val, We, B, L, T> + where + Q: Hash + Equivalent + ToOwned + ?Sized, + { + let (shard, hash) = self.shard_for(key).unwrap(); + // See entry() for explanation of the Option::take pattern. 
+ let mut on_occupied = Some(on_occupied); + let mut callback = |k: &Key, v: &mut Val| on_occupied.take().unwrap()(k, v); + + loop { + // Scope the write guard so it doesn't appear in the async state machine, + // which would make the future !Send. + let result = { + let mut shard_guard = shard.write(); + match shard_guard.entry_or_placeholder(hash, key, &mut callback) { + EntryOrPlaceholder::Kept(t) => Ok(EntryResult::Retained(t)), + EntryOrPlaceholder::Removed(k, v) => Ok(EntryResult::Removed(k, v)), + EntryOrPlaceholder::Replaced(shared, old_val) => { + drop(shard_guard); + Ok(EntryResult::Replaced( + PlaceholderGuard::start_loading(&self.lifecycle, shard, shared), + old_val, + )) + } + EntryOrPlaceholder::NewPlaceholder(shared) => { + drop(shard_guard); + Ok(EntryResult::Vacant(PlaceholderGuard::start_loading( + &self.lifecycle, + shard, + shared, + ))) + } + EntryOrPlaceholder::ExistingPlaceholder(_) => Err(()), + } + }; + match result { + Ok(entry_result) => return entry_result, + Err(()) => match JoinFuture::new(&self.lifecycle, shard, hash, key).await { + JoinResult::Filled(_) => continue, + JoinResult::Guard(g) => return EntryResult::Vacant(g), + JoinResult::Timeout => unsafe { unreachable_unchecked() }, + }, + } + } + } + /// Get total memory used by cache data structures /// /// It should be noted that if cache key or value is some type like `Vec`, @@ -793,4 +965,538 @@ mod tests { let not_found = cache.remove_if(&999, |_| true); assert_eq!(not_found, None); } + + /// Tests all basic entry actions: Retain, Remove, ReplaceWithGuard, Vacant, mutate+Retain + #[test] + fn test_entry_actions() { + let cache = Cache::new(100); + cache.insert(1, 10); + cache.insert(2, 20); + + // Retain returns the value via callback, entry stays + let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + assert!(matches!(result, EntryResult::Retained(10))); + assert_eq!(cache.get(&1), Some(10)); + + // Mutate in place via Retain + let result = cache.entry(&1, 
None, |_k, v| { + *v += 5; + EntryAction::Retain(()) + }); + assert!(matches!(result, EntryResult::Retained(()))); + assert_eq!(cache.get(&1), Some(15)); + + // Remove + let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::Remove); + assert!(matches!(result, EntryResult::Removed(1, 15))); + assert_eq!(cache.get(&1), None); + + // Remove then re-enter same key → Vacant + let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(99); + assert_eq!(cache.get(&1), Some(99)); + } + _ => panic!("expected Vacant for removed key"), + } + + // ReplaceWithGuard: capture old value, get guard, insert new + let mut old_val = 0; + let result = cache.entry(&2, None, |_k, v| { + old_val = *v; + EntryAction::<()>::ReplaceWithGuard + }); + assert_eq!(old_val, 20); + match result { + EntryResult::Replaced(g, old) => { + assert_eq!(old, 20); + let _ = g.insert(old_val + 100); + assert_eq!(cache.get(&2), Some(120)); + } + _ => panic!("expected Replaced"), + } + + // ReplaceWithGuard then abandon guard → entry gone + let result = cache.entry(&2, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard); + match result { + EntryResult::Replaced(g, _old) => { + drop(g); + assert_eq!(cache.get(&2), None); + } + _ => panic!("expected Replaced"), + } + + // Vacant key → guard + let result = cache.entry(&3, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(30); + assert_eq!(cache.get(&3), Some(30)); + } + _ => panic!("expected Vacant"), + } + } + + /// Tests weight tracking across all entry actions using a string-length weighter + #[test] + fn test_entry_weight_tracking() { + #[derive(Clone)] + struct StringWeighter; + impl crate::Weighter for StringWeighter { + fn weight(&self, _key: &u64, val: &String) -> u64 { + val.len() as u64 + } + } + + let cache = Cache::with_weighter(100, 100_000, StringWeighter); + cache.insert(1, "hello".to_string()); + 
cache.insert(2, "world".to_string()); + assert_eq!(cache.weight(), 10); + + // Retain without mutation — weight unchanged + let result = cache.entry(&1, None, |_k, _v| EntryAction::Retain(())); + assert!(matches!(result, EntryResult::Retained(()))); + assert_eq!(cache.weight(), 10); + + // Mutate to longer string — weight increases + let result = cache.entry(&1, None, |_k, v| { + v.push_str(" world"); + EntryAction::Retain(()) + }); + assert!(matches!(result, EntryResult::Retained(()))); + assert_eq!(cache.weight(), 16); // "hello world" (11) + "world" (5) + assert_eq!(cache.get(&1).unwrap(), "hello world"); + + // Mutate to empty string — weight to zero, entry stays + let result = cache.entry(&1, None, |_k, v| { + v.clear(); + EntryAction::Retain(()) + }); + assert!(matches!(result, EntryResult::Retained(()))); + assert_eq!(cache.weight(), 5); // "" (0) + "world" (5) + assert_eq!(cache.get(&1).unwrap(), ""); + + // Remove — weight decremented + let result = cache.entry(&2, None, |_k, _v| EntryAction::<()>::Remove); + assert!(matches!(result, EntryResult::Removed(2, _))); + assert_eq!(cache.weight(), 0); + assert_eq!(cache.len(), 1); + + // ReplaceWithGuard — old weight gone, new weight after insert + cache.insert(3, "hello".to_string()); + assert_eq!(cache.weight(), 5); + let result = cache.entry(&3, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard); + match result { + EntryResult::Replaced(g, _old) => { + assert_eq!(cache.weight(), 0); + let _ = g.insert("hello world!!".to_string()); + assert_eq!(cache.weight(), 13); + } + _ => panic!("expected Replaced"), + } + } + + /// Tests eviction and zero-capacity edge cases + #[test] + fn test_entry_eviction() { + // Cache with capacity for ~2 items — insert 3rd triggers eviction + let cache = Cache::new(2); + cache.insert(1, 10); + cache.insert(2, 20); + assert_eq!(cache.len(), 2); + + let result = cache.entry(&3, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Vacant(g) => { + let _ = 
g.insert(30); + assert!(cache.len() <= 2); + assert_eq!(cache.get(&3), Some(30)); + } + _ => panic!("expected Vacant"), + } + + // Zero-capacity cache — insert evicts immediately + let cache = Cache::new(0); + let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(10); + assert_eq!(cache.get(&1), None); + } + _ => panic!("expected Vacant"), + } + } + + /// Tests entry() waiting on existing placeholder: value arrives, guard abandoned + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_concurrent_placeholder_wait() { + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + // Thread holds guard, inserts after delay + let cache2 = cache.clone(); + let barrier2 = barrier.clone(); + let handle = thread::spawn(move || match cache2.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => { + barrier2.wait(); + std::thread::sleep(Duration::from_millis(50)); + let _ = g.insert(42); + } + _ => panic!("expected guard"), + }); + + barrier.wait(); + let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + assert!(matches!(result, EntryResult::Retained(42))); + handle.join().unwrap(); + } + + /// Tests entry() getting guard when placeholder loader abandons + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_concurrent_placeholder_guard_abandoned() { + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + let cache2 = cache.clone(); + let barrier2 = barrier.clone(); + let handle = thread::spawn(move || match cache2.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => { + barrier2.wait(); + std::thread::sleep(Duration::from_millis(50)); + drop(g); + } + _ => panic!("expected guard"), + }); + + barrier.wait(); + let result = cache.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(99); + assert_eq!(cache.get(&1), Some(99)); + } + _ => panic!("expected 
Vacant after abandoned placeholder"), + } + handle.join().unwrap(); + } + + /// Tests zero and nonzero timeouts + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_timeout() { + let cache = Cache::new(100); + + // Zero timeout — immediate Timeout when placeholder exists + let guard = match cache.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => g, + _ => panic!("expected guard"), + }; + let result = cache.entry(&1, Some(Duration::ZERO), |_k, v| EntryAction::Retain(*v)); + assert!(matches!(result, EntryResult::Timeout)); + let _ = guard.insert(1); + + // Nonzero timeout — guard held longer than timeout + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + let cache2 = cache.clone(); + let barrier2 = barrier.clone(); + let holder = thread::spawn(move || { + let guard = match cache2.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => g, + _ => panic!("expected guard"), + }; + barrier2.wait(); + std::thread::sleep(Duration::from_millis(200)); + let _ = guard.insert(1); + }); + + barrier.wait(); + let result = cache.entry(&1, Some(Duration::from_millis(50)), |_k, v| { + EntryAction::Retain(*v) + }); + assert!(matches!(result, EntryResult::Timeout)); + holder.join().unwrap(); + } + + /// Tests multiple waiters all receiving the value + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_concurrent_multiple_waiters() { + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(4)); // 1 loader + 3 waiters + + let cache1 = cache.clone(); + let barrier1 = barrier.clone(); + let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => { + barrier1.wait(); + std::thread::sleep(Duration::from_millis(50)); + let _ = g.insert(42); + } + _ => panic!("expected guard"), + }); + + let mut waiters = Vec::new(); + for _ in 0..3 { + let cache_c = cache.clone(); + let barrier_c = barrier.clone(); + waiters.push(thread::spawn(move || { + barrier_c.wait(); + let result = 
cache_c.entry(&1, None, |_k, v| EntryAction::Retain(*v)); + match result { + EntryResult::Retained(v) => v, + _ => panic!("expected Value"), + } + })); + } + + loader.join().unwrap(); + for w in waiters { + assert_eq!(w.join().unwrap(), 42); + } + } + + /// Tests ReplaceWithGuard and Remove actions after waiting for a placeholder + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_concurrent_action_after_wait() { + // ReplaceWithGuard after wait + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + let cache1 = cache.clone(); + let barrier1 = barrier.clone(); + let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => { + barrier1.wait(); + std::thread::sleep(Duration::from_millis(50)); + let _ = g.insert(42); + } + _ => panic!("expected guard"), + }); + + barrier.wait(); + let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::ReplaceWithGuard); + match result { + EntryResult::Replaced(g, old) => { + assert_eq!(old, 42); + let _ = g.insert(100); + assert_eq!(cache.get(&1), Some(100)); + } + _ => panic!("expected Replaced"), + } + loader.join().unwrap(); + + // Remove after wait + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + let cache1 = cache.clone(); + let barrier1 = barrier.clone(); + let loader = thread::spawn(move || match cache1.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => { + barrier1.wait(); + std::thread::sleep(Duration::from_millis(50)); + let _ = g.insert(42); + } + _ => panic!("expected guard"), + }); + + barrier.wait(); + let result = cache.entry(&1, None, |_k, _v| EntryAction::<()>::Remove); + assert!(matches!(result, EntryResult::Removed(1, 42))); + assert_eq!(cache.get(&1), None); + loader.join().unwrap(); + } + + /// Multi-thread stress test for entry() + #[test] + #[cfg_attr(miri, ignore)] + fn test_entry_concurrent_stress() { + const N_THREADS: usize = 8; + const N_KEYS: usize = 50; + const N_OPS: usize 
= 500; + + let cache = Arc::new(Cache::new(1000)); + let barrier = Arc::new(Barrier::new(N_THREADS)); + + let mut handles = Vec::new(); + for t in 0..N_THREADS { + let cache = cache.clone(); + let barrier = barrier.clone(); + handles.push(thread::spawn(move || { + barrier.wait(); + for i in 0..N_OPS { + let key = (t * N_OPS + i) % N_KEYS; + let result = cache.entry(&key, Some(Duration::from_millis(10)), |_k, v| { + EntryAction::Retain(*v) + }); + match result { + EntryResult::Retained(_) => {} + EntryResult::Vacant(g) => { + let _ = g.insert(key * 10); + } + EntryResult::Replaced(g, _) => { + let _ = g.insert(key * 10); + } + EntryResult::Timeout => {} + EntryResult::Removed(_, _) => {} + } + } + })); + } + + for h in handles { + h.join().unwrap(); + } + + assert!(cache.len() <= N_KEYS); + for key in 0..N_KEYS { + if let Some(v) = cache.get(&key) { + assert_eq!(v, key * 10); + } + } + } + + // --- Async tests --- + + /// Tests all basic async entry actions in one test + #[tokio::test] + async fn test_entry_async_actions() { + let cache = Cache::new(100); + cache.insert(1, 10); + cache.insert(2, 20); + + // Retain + let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await; + assert!(matches!(result, EntryResult::Retained(10))); + assert_eq!(cache.get(&1), Some(10)); + + // Remove + let result = cache + .entry_async(&1, |_k, _v| EntryAction::<()>::Remove) + .await; + assert!(matches!(result, EntryResult::Removed(1, 10))); + assert_eq!(cache.get(&1), None); + + // ReplaceWithGuard + let result = cache + .entry_async(&2, |_k, _v| EntryAction::<()>::ReplaceWithGuard) + .await; + match result { + EntryResult::Replaced(g, old) => { + assert_eq!(old, 20); + let _ = g.insert(42); + assert_eq!(cache.get(&2), Some(42)); + } + _ => panic!("expected Replaced"), + } + + // Vacant + let result = cache.entry_async(&3, |_k, v| EntryAction::Retain(*v)).await; + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(99); + assert_eq!(cache.get(&3), 
Some(99)); + } + _ => panic!("expected Vacant"), + } + } + + /// Tests async entry waiting on placeholder: value arrives, guard abandoned + #[tokio::test(flavor = "multi_thread")] + async fn test_entry_async_concurrent_wait() { + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + let cache1 = cache.clone(); + let barrier1 = barrier.clone(); + let holder = thread::spawn(move || { + let guard = match cache1.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => g, + _ => panic!("expected guard"), + }; + barrier1.wait(); + std::thread::sleep(Duration::from_millis(50)); + let _ = guard.insert(42); + }); + + barrier.wait(); + let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await; + assert!(matches!(result, EntryResult::Retained(42))); + holder.join().unwrap(); + } + + /// Tests async entry getting guard when placeholder loader abandons + #[tokio::test(flavor = "multi_thread")] + async fn test_entry_async_concurrent_guard_abandoned() { + let cache = Arc::new(Cache::new(100)); + let barrier = Arc::new(Barrier::new(2)); + + let cache1 = cache.clone(); + let barrier1 = barrier.clone(); + let holder = thread::spawn(move || { + let guard = match cache1.get_value_or_guard(&1, None) { + GuardResult::Guard(g) => g, + _ => panic!("expected guard"), + }; + barrier1.wait(); + std::thread::sleep(Duration::from_millis(50)); + drop(guard); + }); + + barrier.wait(); + let result = cache.entry_async(&1, |_k, v| EntryAction::Retain(*v)).await; + match result { + EntryResult::Vacant(g) => { + let _ = g.insert(99); + } + _ => panic!("expected Vacant after abandoned placeholder"), + } + assert_eq!(cache.get(&1), Some(99)); + holder.join().unwrap(); + } + + /// Multi-task async stress test + #[tokio::test(flavor = "multi_thread")] + #[cfg_attr(miri, ignore)] + async fn test_entry_async_concurrent_stress() { + const N_TASKS: usize = 16; + const N_KEYS: usize = 50; + const N_OPS: usize = 200; + + let cache = Arc::new(Cache::new(1000)); 
+ let barrier = Arc::new(tokio::sync::Barrier::new(N_TASKS)); + + let mut handles = Vec::new(); + for t in 0..N_TASKS { + let cache = cache.clone(); + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + barrier.wait().await; + for i in 0..N_OPS { + let key = (t * N_OPS + i) % N_KEYS; + // Use get_or_insert_async instead of entry_async to avoid + // lifetime issues with tokio::spawn (entry_async borrows &self) + let _ = cache + .get_or_insert_async(&key, async { Ok::<_, ()>(key * 10) }) + .await; + } + })); + } + + for h in handles { + h.await.unwrap(); + } + + assert!(cache.len() <= N_KEYS); + for key in 0..N_KEYS { + if let Some(v) = cache.get(&key) { + assert_eq!(v, key * 10); + } + } + } } diff --git a/src/sync_placeholder.rs b/src/sync_placeholder.rs index c3dc521..8f09d11 100644 --- a/src/sync_placeholder.rs +++ b/src/sync_placeholder.rs @@ -2,6 +2,7 @@ use std::{ future::Future, hash::{BuildHasher, Hash}, hint::unreachable_unchecked, + marker::PhantomPinned, mem, pin, task::{self, Poll}, time::{Duration, Instant}, @@ -60,6 +61,14 @@ pub struct Placeholder { value: OnceLock, } +impl Placeholder { + /// Returns the filled value, if any. + #[inline] + pub(crate) fn value(&self) -> Option<&Val> { + self.value.get() + } +} + #[derive(Debug)] pub struct State { /// The waiters list @@ -126,14 +135,59 @@ impl Waiter { } } +/// Result of [`Cache::get_value_or_guard`](crate::sync::Cache::get_value_or_guard). +/// +/// See also [`Cache::get_value_or_guard_async`](crate::sync::Cache::get_value_or_guard_async) +/// which returns `Result` instead. #[derive(Debug)] pub enum GuardResult<'a, Key, Val, We, B, L> { + /// The value was found in the cache. Value(Val), + /// The key was absent; use the guard to insert a value. Guard(PlaceholderGuard<'a, Key, Val, We, B, L>), + /// Timed out waiting for another loader's placeholder. + Timeout, +} + +// Re-export from shard where it's defined. 
+pub use crate::shard::EntryAction; + +/// Result of waiting for a placeholder or [`JoinFuture`]. +pub(crate) enum JoinResult<'a, Key, Val, We, B, L> { + /// Value is available — either found directly in the cache (`None`) or + /// inside the shared placeholder (`Some`). + Filled(Option>), + /// Got the guard — caller should load the value. + Guard(PlaceholderGuard<'a, Key, Val, We, B, L>), + /// Timed out waiting (sync paths only). + Timeout, +} + +/// Result of an [`entry`](crate::sync::Cache::entry) or +/// [`entry_async`](crate::sync::Cache::entry_async) operation. +#[derive(Debug)] +pub enum EntryResult<'a, Key, Val, We, B, L, T> { + /// The key existed and the callback returned [`EntryAction::Retain`]. + /// Contains the value `T` returned by the callback. + Retained(T), + /// The key existed and the callback returned [`EntryAction::Remove`]. + /// Contains the removed key and value. + Removed(Key, Val), + /// The key existed and the callback returned [`EntryAction::ReplaceWithGuard`]. + /// Contains a [`PlaceholderGuard`] for re-insertion and the old value. + Replaced(PlaceholderGuard<'a, Key, Val, We, B, L>, Val), + /// The key was absent. Contains a [`PlaceholderGuard`] for inserting a new value. + Vacant(PlaceholderGuard<'a, Key, Val, We, B, L>), + /// Timed out waiting for another loader's placeholder. + /// + /// Only returned by [`Cache::entry`](crate::sync::Cache::entry), + /// which accepts a `timeout` parameter. For the async variant, use an external + /// timeout mechanism (e.g. `tokio::time::timeout`). 
Timeout, } impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> { + #[inline] pub fn start_loading( lifecycle: &'a L, shard: &'a RwLock>>, @@ -158,14 +212,11 @@ impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> { lifecycle: &'a L, shard: &'a RwLock>>, shared: SharedPlaceholder, - ) -> Result> - where - Val: Clone, - { + ) -> Result, PlaceholderGuard<'a, Key, Val, We, B, L>> { // Check if the value was loaded, and if it wasn't it means we got the // guard and need to start loading the value. - if let Some(v) = shared.value.get() { - Ok(v.clone()) + if shared.value().is_some() { + Ok(shared) } else { Err(PlaceholderGuard::start_loading(lifecycle, shard, shared)) } @@ -179,10 +230,7 @@ impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> { shared: &SharedPlaceholder, // a function that returns a waiter if it should be added waiter_new: impl FnOnce() -> Option, - ) -> Option - where - Val: Clone, - { + ) -> bool { let mut state = shared.state.write(); // _locked_shard could be released here, it would be sufficient to synchronize with the holder // of the guard trying to remove the placeholder from the cache. 
But if this placeholder is hot, @@ -193,13 +241,9 @@ impl<'a, Key, Val, We, B, L> PlaceholderGuard<'a, Key, Val, We, B, L> { if let Some(waiter) = waiter_new() { state.waiters.push(waiter); } - None + false } - LoadingState::Inserted => unsafe { - // SAFETY: The value is guaranteed to be set at this point - drop(state); // Allow cloning outside the lock - Some(shared.value.get().unwrap_unchecked().clone()) - }, + LoadingState::Inserted => true, } } } @@ -218,59 +262,94 @@ impl< shard: &'a RwLock>>, hash: u64, key: &Q, - mut timeout: Option, + timeout: Option, ) -> GuardResult<'a, Key, Val, We, B, L> where Q: Hash + Equivalent + ToOwned + ?Sized, { let mut shard_guard = shard.write(); - let shared = match shard_guard.upsert_placeholder(hash, key) { + let shared = match shard_guard.get_or_placeholder(hash, key) { Ok((_, v)) => return GuardResult::Value(v.clone()), Err((shared, true)) => { return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared)); } Err((shared, false)) => shared, }; + let mut deadline = timeout.map(Ok); + match Self::wait_for_placeholder(lifecycle, shard, shard_guard, shared, deadline.as_mut()) { + JoinResult::Filled(shared) => unsafe { + // SAFETY: Filled means the value was set by the loader. + GuardResult::Value(shared.unwrap_unchecked().value().unwrap_unchecked().clone()) + }, + JoinResult::Guard(g) => GuardResult::Guard(g), + JoinResult::Timeout => GuardResult::Timeout, + } + } - // Create notified flag on stack - this will live for the entire duration of join + /// Waits for an existing placeholder to be filled by another thread. + /// + /// Registers the current thread as a waiter (consuming the shard guard to avoid + /// races with placeholder removal), then parks until notified or timeout. + /// + /// `deadline` is `None` for no timeout, or `Some(&mut Ok(duration))` on the first + /// call. On first use the duration is converted in-place to `Err(instant)` so that + /// callers that retry (e.g. 
`entry`) preserve the original deadline across calls. + pub(crate) fn wait_for_placeholder( + lifecycle: &'a L, + shard: &'a RwLock>>, + shard_guard: RwLockWriteGuard<'a, CacheShard>>, + shared: SharedPlaceholder, + deadline: Option<&mut Result>, + ) -> JoinResult<'a, Key, Val, We, B, L> { let notified = pin::pin!(AtomicBool::new(false)); - // Set if the thread was added to the waiters list let mut parked_thread = None; - let maybe_val = Self::join_waiters(shard_guard, &shared, || { - if timeout.is_some_and(|t| t.is_zero()) { + let already_filled = Self::join_waiters(shard_guard, &shared, || { + // Skip registering a waiter if the timeout is zero. + // An already-elapsed Err(instant) deadline is not checked here; + // the loop below handles it and join_timeout cleans up the waiter. + if matches!(deadline.as_deref(), Some(Ok(d)) if d.is_zero()) { None } else { let thread = thread::current(); - let id = thread.id(); - parked_thread = Some(id); + parked_thread = Some(thread.id()); Some(Waiter::Thread { thread, notified: &*notified as *const AtomicBool, }) } }); - if let Some(v) = maybe_val { - return GuardResult::Value(v); + if already_filled { + return JoinResult::Filled(Some(shared)); } - // Track the start time of the timeout, set lazily - let mut timeout_start = None; + // Lazily convert the duration to a deadline on first call; + // subsequent retries from entry() reuse the same deadline. 
+ let deadline = deadline.and_then(|d| match *d { + Ok(dur) => match Instant::now().checked_add(dur) { + Some(instant) => { + *d = Err(instant); + Some(instant) + } + None => None, // overflow → treat as no timeout (wait forever) + }, + Err(instant) => Some(instant), + }); loop { - if let Some(remaining) = timeout { + if let Some(instant) = deadline { + let remaining = instant.saturating_duration_since(Instant::now()); if remaining.is_zero() { return Self::join_timeout(lifecycle, shard, shared, parked_thread, ¬ified); } - let start = *timeout_start.get_or_insert_with(Instant::now); #[cfg(not(fuzzing))] thread::park_timeout(remaining); - timeout = Some(remaining.saturating_sub(start.elapsed())); } else { + #[cfg(not(fuzzing))] thread::park(); } if notified.load(Ordering::Acquire) { return match Self::handle_notification(lifecycle, shard, shared) { - Ok(v) => GuardResult::Value(v), - Err(g) => GuardResult::Guard(g), + Ok(shared) => JoinResult::Filled(Some(shared)), + Err(g) => JoinResult::Guard(g), }; } } @@ -284,12 +363,12 @@ impl< // when timeout is zero, the thread may have not been added to the waiters list parked_thread: Option, notified: &AtomicBool, - ) -> GuardResult<'a, Key, Val, We, B, L> { + ) -> JoinResult<'a, Key, Val, We, B, L> { let mut state = shared.state.write(); match state.loading { LoadingState::Loading if notified.load(Ordering::Acquire) => { drop(state); // Drop state guard to avoid a deadlock with start_loading - GuardResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared)) + JoinResult::Guard(PlaceholderGuard::start_loading(lifecycle, shard, shared)) } LoadingState::Loading => { if parked_thread.is_some() { @@ -304,12 +383,12 @@ impl< unsafe { unreachable_unchecked() }; } } - GuardResult::Timeout + JoinResult::Timeout + } + LoadingState::Inserted => { + drop(state); + JoinResult::Filled(Some(shared)) } - LoadingState::Inserted => unsafe { - // SAFETY: The value is guaranteed to be set at this point - 
GuardResult::Value(shared.value.get().unwrap_unchecked().clone()) - }, } } } @@ -412,19 +491,29 @@ impl std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We, } } -/// Future that results in an Ok(Value) or Err(Guard) -pub struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> { +/// Future that checks for an existing placeholder and waits for it to be filled. +/// +/// The shard lock is acquired as a local variable inside `poll`, never stored +/// in the future state, so the future remains `Send`. +/// +/// # Pin safety +/// +/// This future is `!Unpin` because `poll` registers `&self.notified` as a raw +/// pointer in the placeholder's waiter list. `Pin` guarantees the future won't +/// be moved after the first poll, keeping that pointer valid. The pointer is +/// cleaned up in `drop_pending_waiter` before the struct is destroyed. +pub(crate) struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> { lifecycle: &'a L, shard: &'a RwLock>>, - state: JoinFutureState<'b, Q, Val>, + hash: u64, + key: &'b Q, + state: JoinFutureState, notified: AtomicBool, + _pin: PhantomPinned, } -enum JoinFutureState<'b, Q: ?Sized, Val> { - Created { - hash: u64, - key: &'b Q, - }, +enum JoinFutureState { + Created, Pending { shared: SharedPlaceholder, waker: task::Waker, @@ -433,20 +522,25 @@ enum JoinFutureState<'b, Q: ?Sized, Val> { } impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> { - pub fn new( + pub(crate) fn new( lifecycle: &'a L, shard: &'a RwLock>>, hash: u64, key: &'b Q, - ) -> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> { + ) -> Self { Self { lifecycle, shard, - state: JoinFutureState::Created { hash, key }, + hash, + key, + state: JoinFutureState::Created, notified: Default::default(), + _pin: PhantomPinned, } } +} +impl JoinFuture<'_, '_, Q, Key, Val, We, B, L> { #[cold] fn drop_pending_waiter(&mut self) { let JoinFutureState::Pending { shared, .. 
} = @@ -475,7 +569,7 @@ impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, unsafe { unreachable_unchecked() } } } - LoadingState::Inserted => (), // We were notified but didn't get polled - nothing to do + LoadingState::Inserted => (), // Notified but didn't get polled - nothing to do } } } @@ -493,35 +587,41 @@ impl< 'a, Key: Eq + Hash, Q: Hash + Equivalent + ToOwned + ?Sized, - Val: Clone, + Val, We: Weighter, B: BuildHasher, L: Lifecycle, > Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L> { - type Output = Result>; + type Output = JoinResult<'a, Key, Val, We, B, L>; - fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - let this = &mut *self; + fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + // SAFETY: We never move the struct out of the Pin — only read/write individual + // fields. The `notified` field's address (registered in the waiter list) stays + // stable because Pin guarantees the future won't be moved. + let this = unsafe { self.get_unchecked_mut() }; let lifecycle = this.lifecycle; let shard = this.shard; match &mut this.state { - JoinFutureState::Created { hash, key } => { - debug_assert!(!this.notified.load(Ordering::Acquire)); + JoinFutureState::Created => { let mut shard_guard = shard.write(); - match shard_guard.upsert_placeholder(*hash, *key) { - Ok((_, v)) => { + match shard_guard.get_or_placeholder(this.hash, this.key) { + Ok(_) => { this.state = JoinFutureState::Done; - Poll::Ready(Ok(v.clone())) + Poll::Ready(JoinResult::Filled(None)) } Err((shared, true)) => { - let guard = PlaceholderGuard::start_loading(lifecycle, shard, shared); this.state = JoinFutureState::Done; - Poll::Ready(Err(guard)) + drop(shard_guard); + Poll::Ready(JoinResult::Guard(PlaceholderGuard::start_loading( + lifecycle, shard, shared, + ))) } Err((shared, false)) => { + // Register as waiter while holding shard lock — prevents + // race with drop_uninserted_slow removing the placeholder. 
let mut waker = None; - let maybe_val = + let already_filled = PlaceholderGuard::join_waiters(shard_guard, &shared, || { let waker_ = cx.waker().clone(); waker = Some(waker_.clone()); @@ -530,34 +630,29 @@ impl< notified: &this.notified as *const AtomicBool, }) }); - if let Some(v) = maybe_val { - debug_assert!(waker.is_none()); - debug_assert!(!this.notified.load(Ordering::Acquire)); + if already_filled { this.state = JoinFutureState::Done; - Poll::Ready(Ok(v)) + Poll::Ready(JoinResult::Filled(Some(shared))) } else { - let waker = waker.unwrap(); - this.state = JoinFutureState::Pending { shared, waker }; + this.state = JoinFutureState::Pending { + shared, + waker: waker.unwrap(), + }; Poll::Pending } } } } JoinFutureState::Pending { waker, shared } => { - 'notified: { - if this.notified.load(Ordering::Acquire) { - break 'notified; - } - // Update waker in case it changed + if !this.notified.load(Ordering::Acquire) { let new_waker = cx.waker(); - if !waker.will_wake(new_waker) { - let mut state = shared.state.write(); - // Re-check notified after acquiring the lock. A concurrent - // insert may have drained the waiters list between the - // notified check above and this point. - if this.notified.load(Ordering::Acquire) { - break 'notified; - } + if waker.will_wake(new_waker) { + return Poll::Pending; + } + let mut state = shared.state.write(); + // Re-check after acquiring the lock — a concurrent insert + // may have drained the waiters list in the meantime. + if !this.notified.load(Ordering::Acquire) { let w = unsafe { state .waiters @@ -570,17 +665,20 @@ impl< waker: new_waker.clone(), notified: &this.notified as *const AtomicBool, }; + return Poll::Pending; } - return Poll::Pending; - }; + } let JoinFutureState::Pending { shared, .. 
} = mem::replace(&mut this.state, JoinFutureState::Done) else { unsafe { unreachable_unchecked() } }; - Poll::Ready(PlaceholderGuard::handle_notification( - lifecycle, shard, shared, - )) + Poll::Ready( + match PlaceholderGuard::handle_notification(lifecycle, shard, shared) { + Ok(shared) => JoinResult::Filled(Some(shared)), + Err(g) => JoinResult::Guard(g), + }, + ) } JoinFutureState::Done => panic!("Polled after ready"), } diff --git a/src/unsync.rs b/src/unsync.rs index 8d4e00a..66ee954 100644 --- a/src/unsync.rs +++ b/src/unsync.rs @@ -135,13 +135,13 @@ impl, B: BuildHasher, L: Lifecycle(&self, key: &Q) -> bool where Q: Hash + Equivalent + ?Sized, @@ -261,7 +261,7 @@ impl, B: BuildHasher, L: Lifecycle + ToOwned + ?Sized, { - let idx = match self.shard.upsert_placeholder(self.shard.hash(key), key) { + let idx = match self.shard.get_or_placeholder(self.shard.hash(key), key) { Ok((idx, _)) => idx, Err((plh, _)) => { let v = with()?; @@ -287,7 +287,7 @@ impl, B: BuildHasher, L: Lifecycle + ToOwned + ?Sized, { - let idx = match self.shard.upsert_placeholder(self.shard.hash(key), key) { + let idx = match self.shard.get_or_placeholder(self.shard.hash(key), key) { Ok((idx, _)) => idx, Err((plh, _)) => { let v = with()?; @@ -302,14 +302,14 @@ impl, B: BuildHasher, L: Lifecycle(&mut self, key: &Q) -> Result<&Val, Guard<'_, Key, Val, We, B, L>> where Q: Hash + Equivalent + ToOwned + ?Sized, { // TODO: this could be using a simpler entry API - match self.shard.upsert_placeholder(self.shard.hash(key), key) { + match self.shard.get_or_placeholder(self.shard.hash(key), key) { Ok((_, v)) => unsafe { // Rustc gets insanely confused about returning from mut borrows // Safety: v has the same lifetime as self @@ -325,7 +325,7 @@ impl, B: BuildHasher, L: Lifecycle, B: BuildHasher, L: Lifecycle + ToOwned + ?Sized, { // TODO: this could be using a simpler entry API - match self.shard.upsert_placeholder(self.shard.hash(key), key) { + match 
self.shard.get_or_placeholder(self.shard.hash(key), key) { Ok((idx, _)) => Ok(self.shard.peek_token_mut(idx).map(RefMut)), Err((placeholder, _)) => Err(Guard { cache: self,