From 103bfb8a94787b51c5a4e802e69d27dd88b79c77 Mon Sep 17 00:00:00 2001 From: Jingyu Ma Date: Tue, 24 Mar 2026 21:45:50 -0700 Subject: [PATCH] fix: restore original function bytes on x86_64 when mock is dropped When a ThreadRegistration is dropped and ref_count reaches 0, restore the original function bytes, free the JIT dispatcher and trampoline memory, and remove the registry entry. This eliminates the permanent dispatcher overhead for unmocked functions. Previously, after a mock was dropped, the JIT dispatcher remained installed at the function's entry point. While functionally correct (the dispatcher returned the original behavior via a trampoline), this caused stack overflow in downstream projects when running 500+ tests sequentially. Deep call chains through previously-patched functions accumulated dispatcher overhead (get_thread_target + catch_unwind + TLS lookup per call) that exceeded the default stack. The restoration is gated behind cfg(target_arch = "x86_64") since the ARM64/ARM32 concern about asynchronous instruction cache invalidation does not apply to x86_64's coherent instruction cache. Includes regression tests that verify: - Function machine code bytes are restored after mock is dropped - Repeated patch/drop cycles work correctly (registry cleanup) - Function behavior is correct across the mock lifecycle Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/injector_core/thread_local_registry.rs | 47 +++--- tests/stack_overflow.rs | 157 +++++++++++++++++++++ 2 files changed, 188 insertions(+), 16 deletions(-) create mode 100644 tests/stack_overflow.rs diff --git a/src/injector_core/thread_local_registry.rs b/src/injector_core/thread_local_registry.rs index 9c7b814..9323e3c 100644 --- a/src/injector_core/thread_local_registry.rs +++ b/src/injector_core/thread_local_registry.rs @@ -116,25 +116,40 @@ impl Drop for ThreadRegistration { } // Decrement ref_count in global registry. - // We intentionally do NOT restore the original function bytes or free the - // dispatcher/trampoline when ref_count reaches 0. The dispatcher remains - // patched into the function permanently. When no thread has a replacement - // registered, the dispatcher routes through the trampoline to the original - // function, preserving correct behavior. - // - // This avoids a race condition on ARM64 (and theoretically x86_64) where - // restoring the original bytes and freeing the dispatcher/trampoline can - // race with another CPU core still executing inside the dispatcher or - // trampoline from a prior call. On ARM64, instruction cache invalidation - // is asynchronous across cores, so another core may still be fetching - // pre-invalidation instructions when the memory is freed. - // - // The tradeoff is a small amount of leaked JIT memory (~300 bytes per - // unique function ever patched) and a minor overhead for calling unpatched - // functions (one TLS lookup per call). Both are negligible for test code. let mut registry = REGISTRY.lock().unwrap_or_else(|e| e.into_inner()); if let Some(entry) = registry.get_mut(&self.method_key) { entry.ref_count = entry.ref_count.saturating_sub(1); + + // On x86_64, restore original function bytes and free JIT memory when + // no thread has an active replacement. This eliminates per-call overhead + // (dispatcher + TLS lookup) for unpatched functions, preventing stack + // overflow when many functions are patched across sequential tests. + // + // On ARM64/ARM32, we intentionally do NOT restore to avoid a race + // condition where instruction cache invalidation is asynchronous across + // cores — another core may still be executing pre-invalidation + // instructions when the memory is freed. + #[cfg(target_arch = "x86_64")] + if entry.ref_count == 0 { + unsafe { + // Restore original function bytes + patch_function(entry.func_ptr, &entry.original_bytes[..entry.patch_size]); + + // Free dispatcher JIT memory + if !entry.dispatcher_jit.is_null() { + free_jit_block(entry.dispatcher_jit, entry.dispatcher_jit_size); + } + + // Free trampoline JIT memory + if !entry.trampoline.is_null() { + free_jit_block(entry.trampoline, entry.trampoline_size); + } + } + + // Remove the entry from the registry so it can be re-patched fresh + // if the same function is faked again in a later test. + registry.remove(&self.method_key); + } } } } diff --git a/tests/stack_overflow.rs b/tests/stack_overflow.rs new file mode 100644 index 0000000..29a27f7 --- /dev/null +++ b/tests/stack_overflow.rs @@ -0,0 +1,157 @@ +// Regression tests for function restoration after mock is dropped (x86_64 Windows). +// +// In v0.5.0, thread-local dispatch permanently patches functions: after a mock +// is dropped, the JIT dispatcher jump instruction remains at the function's entry +// point. While functionally correct (the dispatcher returns the original behavior +// via a trampoline), this has two problems: +// +// 1. **Stack overhead**: Each call through the dispatcher temporarily pushes extra +// stack for get_thread_target + catch_unwind + TLS lookup. In deep call chains +// close to the stack limit, this can trigger stack overflow. This was observed +// in the acs_media_sdk test suite where test #209/503 crashed with +// STATUS_STACK_OVERFLOW (0xC00000FD) when running sequentially. +// +// 2. **Resource leak**: JIT memory (dispatcher + trampoline) is never freed, +// and registry entries accumulate across tests. +// +// The fix restores original function bytes on x86_64 when ref_count drops to 0, +// and frees associated JIT memory. +#![cfg(target_arch = "x86_64")] +#![cfg(target_os = "windows")] + +use injectorpp::interface::injector::*; + +#[inline(never)] +fn recursive_func(depth: u32) -> u32 { + if depth == 0 { + return 0; + } + std::hint::black_box(recursive_func(depth - 1)) + 1 +} + +#[inline(never)] +fn fake_recursive(_depth: u32) -> u32 { + 0 +} + +/// Read the first `n` bytes of machine code at a function's entry point. +unsafe fn read_func_bytes(func_ptr: *const u8, n: usize) -> Vec { + std::slice::from_raw_parts(func_ptr, n).to_vec() +} + +/// Test that the function's machine code is fully restored after mock is dropped. +/// +/// Without the fix: the entry point remains a JMP to the dispatcher (0xE9 ...) +/// With the fix: the original prologue bytes are restored +#[test] +fn test_function_bytes_restored_after_drop() { + let func_ptr = recursive_func as *const u8; + + // Read original bytes before any patching + let original_bytes = unsafe { read_func_bytes(func_ptr, 16) }; + + // The first byte should NOT be 0xE9 (JMP rel32) before patching + assert_ne!( + original_bytes[0], 0xE9, + "Function should not start with JMP before patching" + ); + + { + let mut injector = InjectorPP::new(); + injector + .when_called(injectorpp::func!(fn(recursive_func)(u32) -> u32)) + .will_execute_raw(injectorpp::func!(fn(fake_recursive)(u32) -> u32)); + + // While patched, the first bytes should be a JMP (0xE9 for rel32) + let patched_bytes = unsafe { read_func_bytes(func_ptr, 16) }; + assert_eq!( + patched_bytes[0], 0xE9, + "Patched function should start with JMP rel32 (0xE9), got 0x{:02X}", + patched_bytes[0] + ); + } + + // After drop, original bytes should be restored + let restored_bytes = unsafe { read_func_bytes(func_ptr, 16) }; + assert_eq!( + restored_bytes, original_bytes, + "Function bytes should be fully restored after mock is dropped.\n\ + Original: {:02X?}\n\ + Restored: {:02X?}", + original_bytes, restored_bytes + ); +} + +/// Test that a function can be correctly patched, dropped, and re-patched +/// multiple times. This verifies that registry cleanup allows fresh re-patching. +#[test] +fn test_repeated_patch_and_restore_cycles() { + let func_ptr = recursive_func as *const u8; + let original_bytes = unsafe { read_func_bytes(func_ptr, 16) }; + + for i in 0..20 { + // Before patching: original bytes + let before = unsafe { read_func_bytes(func_ptr, 16) }; + assert_eq!( + before, original_bytes, + "Cycle {}: bytes should match original before patching", + i + ); + + { + let mut injector = InjectorPP::new(); + injector + .when_called(injectorpp::func!(fn(recursive_func)(u32) -> u32)) + .will_execute_raw(injectorpp::func!(fn(fake_recursive)(u32) -> u32)); + + // While patched: should have JMP + let during = unsafe { read_func_bytes(func_ptr, 5) }; + assert_eq!(during[0], 0xE9, "Cycle {}: should be patched with JMP", i); + + // Mock should be active + assert_eq!(recursive_func(3), 0, "Cycle {}: mock should return 0", i); + } + + // After drop: original bytes restored + let after = unsafe { read_func_bytes(func_ptr, 16) }; + assert_eq!( + after, original_bytes, + "Cycle {}: bytes should be restored after drop", + i + ); + assert_eq!( + recursive_func(3), + 3, + "Cycle {}: function should work normally", + i + ); + } +} + +/// Verify behavior: the function returns correct values before, during, and +/// after mocking, including on a thread with limited stack. +#[test] +fn test_function_behavior_across_mock_lifecycle() { + assert_eq!(recursive_func(10), 10, "Should work before patching"); + + { + let mut injector = InjectorPP::new(); + injector + .when_called(injectorpp::func!(fn(recursive_func)(u32) -> u32)) + .will_execute_raw(injectorpp::func!(fn(fake_recursive)(u32) -> u32)); + assert_eq!(recursive_func(10), 0, "Should return mock value during patch"); + } + + assert_eq!(recursive_func(10), 10, "Should return original value after drop"); + + // Deep recursion should also work after restoration + let handle = std::thread::Builder::new() + .stack_size(256 * 1024) + .spawn(move || recursive_func(2000)) + .unwrap(); + assert_eq!( + handle.join().unwrap(), + 2000, + "Deep recursion should work after restoration" + ); +}