Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ description = "Injectorpp is a powerful tool designed to facilitate the writing

[dependencies]
libc = "0.2"
injectorpp-macros = { path = "injectorpp-macros" }

[target.'cfg(target_os = "macos")'.dependencies]
mach2 = "0.5"
Expand Down
16 changes: 16 additions & 0 deletions injectorpp-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "injectorpp-macros"
version = "0.5.0"
authors = ["Jingyu Ma <mazong1123us@gmail.com>"]
license = "MIT"
repository = "https://github.com/microsoft/injectorppforrust"
edition = "2021"
description = "Proc macros for injectorpp - compile-time lifetime safety checks"

[lib]
proc-macro = true

[dependencies]
syn = { version = "2", features = ["full"] }
quote = "1"
proc-macro2 = "1"
177 changes: 177 additions & 0 deletions injectorpp-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Token, Type, parenthesized, punctuated::Punctuated};

/// Parsed input for the simplified `fn` syntax of `func!`.
struct FuncInput {
func_expr: Expr,
arg_types: Vec<Type>,
return_type: Option<Type>,
is_unsafe: bool,
extern_abi: Option<String>,
}

impl Parse for FuncInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
// Strip optional "func_info:" prefix
if input.peek(syn::Ident) {
let fork = input.fork();
let ident: syn::Ident = fork.parse()?;
if ident == "func_info" && fork.peek(Token![:]) {
// Consume "func_info:"
let _: syn::Ident = input.parse()?;
let _: Token![:] = input.parse()?;
}
}

// Parse optional "unsafe"
let is_unsafe = input.peek(Token![unsafe]);
if is_unsafe {
let _: Token![unsafe] = input.parse()?;
// Skip optional empty braces "{}" used in some macro arms
if input.peek(syn::token::Brace) {
let content;
syn::braced!(content in input);
let _ = content;
}
}

// Parse optional extern "ABI"
let extern_abi = if input.peek(Token![extern]) {
let _: Token![extern] = input.parse()?;
let abi: syn::LitStr = input.parse()?;
Some(abi.value())
} else {
None
};

// Parse "fn"
let _: Token![fn] = input.parse()?;

// Parse "( func_expr )"
let func_content;
parenthesized!(func_content in input);
let func_expr: Expr = func_content.parse()?;

// Parse "( arg_types )"
let types_content;
parenthesized!(types_content in input);
let arg_types: Punctuated<Type, Token![,]> =
Punctuated::parse_terminated(&types_content)?;

// Parse optional "-> return_type"
let return_type = if input.peek(Token![->]) {
let _: Token![->] = input.parse()?;
Some(input.parse::<Type>()?)
} else {
None
};

Ok(FuncInput {
func_expr,
arg_types: arg_types.into_iter().collect(),
return_type,
is_unsafe,
extern_abi,
})
}
}

/// Check if a type is a bare reference (& without explicit lifetime).
fn is_bare_reference(ty: &Type) -> bool {
match ty {
Type::Reference(ref_type) => ref_type.lifetime.is_none(),
_ => false,
}
}

/// Proc macro that implements the simplified `fn` syntax for `func!` with
/// automatic compile-time lifetime safety checks.
///
/// When the return type is a bare reference (e.g., `&str`, `&[u8]`), this macro
/// generates an invariance-based check that detects lifetime mismatches at compile
/// time. This catches the issue #73 pattern where `fn(&str) -> &'static str` is
/// incorrectly specified as `fn(&str) -> &str`.
///
/// For functions with genuinely linked lifetimes (e.g., `fn(&str) -> &str` where
/// the output lifetime matches the input), use an explicit lifetime annotation:
/// `func!(fn (f)(&str) -> &'_ str)`.
#[proc_macro]
pub fn func_checked(input: TokenStream) -> TokenStream {
let parsed = match syn::parse::<FuncInput>(input) {
Ok(parsed) => parsed,
Err(err) => return err.to_compile_error().into(),
};

let func_expr = &parsed.func_expr;
let arg_types = &parsed.arg_types;

// Build the function pointer type
let fn_type = match (&parsed.return_type, parsed.is_unsafe, &parsed.extern_abi) {
(Some(ret), false, None) => quote! { fn(#(#arg_types),*) -> #ret },
(None, false, None) => quote! { fn(#(#arg_types),*) },
(Some(ret), true, None) => quote! { unsafe fn(#(#arg_types),*) -> #ret },
(None, true, None) => quote! { unsafe fn(#(#arg_types),*) -> () },
(Some(ret), true, Some(abi)) => {
let abi_lit = syn::LitStr::new(abi, proc_macro2::Span::call_site());
quote! { unsafe extern #abi_lit fn(#(#arg_types),*) -> #ret }
}
(None, true, Some(abi)) => {
let abi_lit = syn::LitStr::new(abi, proc_macro2::Span::call_site());
quote! { unsafe extern #abi_lit fn(#(#arg_types),*) -> () }
}
(Some(ret), false, Some(abi)) => {
let abi_lit = syn::LitStr::new(abi, proc_macro2::Span::call_site());
quote! { extern #abi_lit fn(#(#arg_types),*) -> #ret }
}
(None, false, Some(abi)) => {
let abi_lit = syn::LitStr::new(abi, proc_macro2::Span::call_site());
quote! { extern #abi_lit fn(#(#arg_types),*) }
}
};

// Generate the lifetime invariance check for bare reference returns.
// This only applies to non-unsafe functions (unsafe/extern functions
// typically don't have Rust lifetime semantics).
let lifetime_check = if !parsed.is_unsafe {
if let Some(ref ret) = parsed.return_type {
if is_bare_reference(ret) {
quote! {
{
fn __injpp_check_ret<__R>(
_f: fn(#(#arg_types),*) -> __R,
) -> fn(#(#arg_types),*) -> __R {
_f
}
fn __injpp_eq<__T>(_: &mut __T, _: &mut __T) {}
let mut __a = __injpp_check_ret(#func_expr);
let mut __b: fn(#(#arg_types),*) -> #ret = #func_expr;
__injpp_eq(&mut __a, &mut __b);
}
}
} else {
quote! {}
}
} else {
quote! {}
}
} else {
quote! {}
};

let output = quote! {
{
#lifetime_check
{
let fn_val: #fn_type = #func_expr;
let ptr = fn_val as *const ();
let sig = std::any::type_name_of_val(&fn_val);
let type_id = std::any::TypeId::of::<#fn_type>();
unsafe { FuncPtr::new_with_type_id(ptr, sig, type_id) }
}
}
};

output.into()
}
23 changes: 23 additions & 0 deletions src/interface/func_ptr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::injector_core::common::FuncPtrInternal;
use std::any::TypeId;
use std::ptr::NonNull;

/// A safe wrapper around a raw function pointer.
Expand All @@ -17,6 +18,7 @@ pub struct FuncPtr {
/// This is a wrapper around a non-null pointer to ensure safety.
pub(super) func_ptr_internal: FuncPtrInternal,
pub(super) signature: &'static str,
pub(super) type_id: Option<TypeId>,
}

impl FuncPtr {
Expand All @@ -35,6 +37,27 @@ impl FuncPtr {
Self {
func_ptr_internal: FuncPtrInternal::new(nn),
signature,
type_id: None,
}
}

/// Creates a new `FuncPtr` from a raw pointer with type identity information.
///
/// # Safety
///
/// The caller must ensure that the pointer is valid and points to a function.
pub unsafe fn new_with_type_id(
ptr: *const (),
signature: &'static str,
type_id: TypeId,
) -> Self {
let p = ptr as *mut ();
let nn = NonNull::new(p).expect("Pointer must not be null");

Self {
func_ptr_internal: FuncPtrInternal::new(nn),
signature,
type_id: Some(type_id),
}
}
}
65 changes: 49 additions & 16 deletions src/interface/injector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::injector_core::common::*;
use crate::injector_core::internal::*;
pub use crate::interface::func_ptr::FuncPtr;
pub use crate::interface::macros::__assert_future_output;
pub use crate::interface::macros::__type_id_of_val;
pub use crate::interface::verifier::CallCountVerifier;

use std::future::Future;
Expand Down Expand Up @@ -166,6 +167,7 @@ impl InjectorPP {
lib: self,
when,
expected_signature: func.signature,
expected_type_id: func.type_id,
}
}

Expand Down Expand Up @@ -211,6 +213,7 @@ impl InjectorPP {
lib: self,
when,
expected_signature: "",
expected_type_id: None,
}
}

Expand Down Expand Up @@ -254,15 +257,16 @@ impl InjectorPP {
F: Future<Output = T>,
{
let poll_fn: fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T> = <F as Future>::poll;
let when = WhenCalled::new(
crate::func!(poll_fn, fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T>).func_ptr_internal,
);
let when = WhenCalled::new(unsafe {
FuncPtr::new(poll_fn as *const (), std::any::type_name_of_val(&poll_fn))
}.func_ptr_internal);

let signature = fake_pair.1;
WhenCalledBuilderAsync {
lib: self,
when,
expected_signature: signature,
expected_type_id: None,
}
}

Expand Down Expand Up @@ -313,14 +317,15 @@ impl InjectorPP {
F: Future<Output = T>,
{
let poll_fn: fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T> = <F as Future>::poll;
let when = WhenCalled::new(
crate::func!(poll_fn, fn(Pin<&mut F>, &mut Context<'_>) -> Poll<T>).func_ptr_internal,
);
let when = WhenCalled::new(unsafe {
FuncPtr::new(poll_fn as *const (), std::any::type_name_of_val(&poll_fn))
}.func_ptr_internal);

WhenCalledBuilderAsync {
lib: self,
when,
expected_signature: "",
expected_type_id: None,
}
}
}
Expand Down Expand Up @@ -356,6 +361,7 @@ pub struct WhenCalledBuilder<'a> {
lib: &'a mut InjectorPP,
when: WhenCalled,
expected_signature: &'static str,
expected_type_id: Option<std::any::TypeId>,
}

impl WhenCalledBuilder<'_> {
Expand Down Expand Up @@ -403,11 +409,24 @@ impl WhenCalledBuilder<'_> {
/// assert!(Path::new("/nonexistent").exists());
/// ```
pub fn will_execute_raw(self, target: FuncPtr) {
if normalize_signature(target.signature) != normalize_signature(self.expected_signature) {
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
match (self.expected_type_id, target.type_id) {
(Some(expected), Some(actual)) if expected != actual => {
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
}
(None, _) | (_, None) => {
if normalize_signature(target.signature)
!= normalize_signature(self.expected_signature)
{
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
}
}
_ => {}
}

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
Expand Down Expand Up @@ -577,6 +596,7 @@ pub struct WhenCalledBuilderAsync<'a> {
lib: &'a mut InjectorPP,
when: WhenCalled,
expected_signature: &'static str,
expected_type_id: Option<std::any::TypeId>,
}

impl WhenCalledBuilderAsync<'_> {
Expand Down Expand Up @@ -605,11 +625,24 @@ impl WhenCalledBuilderAsync<'_> {
/// }
/// ```
pub fn will_return_async(self, target: FuncPtr) {
if normalize_signature(target.signature) != normalize_signature(self.expected_signature) {
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
match (self.expected_type_id, target.type_id) {
(Some(expected), Some(actual)) if expected != actual => {
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
}
(None, _) | (_, None) => {
if normalize_signature(target.signature)
!= normalize_signature(self.expected_signature)
{
panic!(
"Signature mismatch: expected {:?} but got {:?}",
self.expected_signature, target.signature
);
}
}
_ => {}
}

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64", target_arch = "arm"))]
Expand Down
Loading
Loading