From fec3119617d2b8421ee5cdab34ec385243cc5ff9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:29:15 +0000 Subject: [PATCH 1/3] Initial plan From f2fec0bcf666b8c5b375111ff9d59bfacd20f955 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:42:15 +0000 Subject: [PATCH 2/3] Add closure variants to OperatorFunction for capturing external state Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- CUSTOM_FUNCTIONS.md | 98 ++++++++++++++++- src/enforcer.rs | 218 +++++++++++++++++++++++++++++++++++++- src/model/function_map.rs | 61 +++++++++-- 3 files changed, 363 insertions(+), 14 deletions(-) diff --git a/CUSTOM_FUNCTIONS.md b/CUSTOM_FUNCTIONS.md index ab0111ea..37a189e0 100644 --- a/CUSTOM_FUNCTIONS.md +++ b/CUSTOM_FUNCTIONS.md @@ -137,7 +137,11 @@ m = greaterThan(r.age, 18) && stringContains(r.path, p.path) ## OperatorFunction Variants -The `OperatorFunction` enum supports functions with 0 to 6 arguments: +The `OperatorFunction` enum supports functions with 0 to 6 arguments. There are two types of variants: + +### Function Pointer Variants (Stateless) + +These use simple function pointers and cannot capture external state: - `Arg0`: `fn() -> Dynamic` - `Arg1`: `fn(Dynamic) -> Dynamic` @@ -147,6 +151,98 @@ The `OperatorFunction` enum supports functions with 0 to 6 arguments: - `Arg5`: `fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic` - `Arg6`: `fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic` +### Closure Variants (Can Capture State) + +These use `Arc` and can capture external state like database connections: + +- `Arg0Closure`: `Arc Dynamic + Send + Sync>` +- `Arg1Closure`: `Arc Dynamic + Send + Sync>` +- `Arg2Closure`: `Arc Dynamic + Send + Sync>` +- `Arg3Closure`: `Arc Dynamic + Send + Sync>` +- `Arg4Closure`: `Arc Dynamic + Send + Sync>` +- `Arg5Closure`: `Arc Dynamic + Send + Sync>` +- `Arg6Closure`: `Arc Dynamic + Send + Sync>` + +## Capturing External State with Closures + +One of the key features of the closure variants is the ability to capture external state. This is useful when you need to access application state (like database connections) from within your custom functions. + +### Example: Capturing Database Connection in Axum + +```rust +use casbin::{CoreApi, Enforcer, OperatorFunction}; +use rhai::Dynamic; +use std::sync::Arc; + +#[derive(Clone)] +struct AppState { + db_connection: Arc, + casbin_enforcer: Arc>, +} + +impl AppState { + async fn setup_enforcer_with_db_check(&mut self) { + // Clone the Arc to capture it in the closure + let db_conn = self.db_connection.clone(); + + // Create a closure that captures the database connection + let check_fn = Arc::new(move |product_id: Dynamic| { + let product_id_int = product_id.as_int().unwrap_or(0); + + // Access the database connection from the captured state + // Note: In real code, you'd need to handle async properly + let has_storages = db_conn.check_product_has_storages(product_id_int); + + has_storages.into() + }); + + // Register the closure-based function + if let Ok(mut enforcer) = self.casbin_enforcer.lock() { + enforcer.add_function( + "matchProductHasStorages", + OperatorFunction::Arg1Closure(check_fn), + ); + } + } +} +``` + +### Example: Multi-Argument Closure with Shared State + +```rust +use casbin::{CoreApi, Enforcer, OperatorFunction}; +use rhai::Dynamic; +use std::sync::Arc; +use std::collections::HashMap; + +// Simulate external configuration +let prefix_map: Arc> = Arc::new({ + let mut m = HashMap::new(); + m.insert("data1".to_string(), "/api/v1/".to_string()); + m.insert("data2".to_string(), "/api/v2/".to_string()); + m +}); + +// Clone for the closure +let prefix_clone = prefix_map.clone(); + +// Create a two-argument closure that uses the shared state +e.add_function( + "customPathCheck", + OperatorFunction::Arg2Closure(Arc::new(move |request_path: Dynamic, policy_resource: Dynamic| { + let req_path = request_path.to_string(); + let policy_res = policy_resource.to_string(); + + // Use the captured prefix_map to determine if paths match + if let Some(prefix) = prefix_clone.get(&policy_res) { + req_path.starts_with(prefix).into() + } else { + (req_path == policy_res).into() + } + })), +); +``` + ## Working with Dynamic Types Rhai's `Dynamic` type provides several methods to extract values: diff --git a/src/enforcer.rs b/src/enforcer.rs index c63cf36c..d2e67688 100644 --- a/src/enforcer.rs +++ b/src/enforcer.rs @@ -403,6 +403,56 @@ impl Enforcer { OperatorFunction::Arg6(func) => { engine.register_fn(key, func); } + // Closure variants + OperatorFunction::Arg0Closure(func) => { + engine.register_fn(key, move || func()); + } + OperatorFunction::Arg1Closure(func) => { + engine.register_fn(key, move |a: Dynamic| func(a)); + } + OperatorFunction::Arg2Closure(func) => { + engine + .register_fn(key, move |a: Dynamic, b: Dynamic| func(a, b)); + } + OperatorFunction::Arg3Closure(func) => { + engine.register_fn( + key, + move |a: Dynamic, b: Dynamic, c: Dynamic| func(a, b, c), + ); + } + OperatorFunction::Arg4Closure(func) => { + engine.register_fn( + key, + move |a: Dynamic, b: Dynamic, c: Dynamic, d: Dynamic| { + func(a, b, c, d) + }, + ); + } + OperatorFunction::Arg5Closure(func) => { + engine.register_fn( + key, + move |a: Dynamic, + b: Dynamic, + c: Dynamic, + d: Dynamic, + e: Dynamic| { + func(a, b, c, d, e) + }, + ); + } + OperatorFunction::Arg6Closure(func) => { + engine.register_fn( + key, + move |a: Dynamic, + b: Dynamic, + c: Dynamic, + d: Dynamic, + e: Dynamic, + g: Dynamic| { + func(a, b, c, d, e, g) + }, + ); + } } } @@ -434,8 +484,8 @@ impl CoreApi for Enforcer { engine.register_global_module(CASBIN_PACKAGE.as_shared_module()); - for (key, &func) in fm.get_functions() { - Self::register_function(&mut engine, key, func); + for (key, func) in fm.get_functions() { + Self::register_function(&mut engine, key, func.clone()); } let mut e = Self { @@ -488,7 +538,7 @@ impl CoreApi for Enforcer { #[inline] fn add_function(&mut self, fname: &str, f: OperatorFunction) { - self.fm.add_function(fname, f); + self.fm.add_function(fname, f.clone()); Self::register_function(&mut self.engine, fname, f); } @@ -1962,4 +2012,166 @@ m = r.sub == p.sub && r.obj == p.obj && r.act == p.act assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap()); assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap()); } + + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr( + all(feature = "runtime-async-std", not(target_arch = "wasm32")), + async_std::test + )] + #[cfg_attr( + all(feature = "runtime-tokio", not(target_arch = "wasm32")), + tokio::test + )] + async fn test_custom_function_with_closure_capturing_state() { + use crate::prelude::*; + use std::sync::Arc; + + let m = DefaultModel::from_str( + r#" +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = r.sub == p.sub && r.obj == p.obj && r.act == p.act && checkPermission(r.act) +"#, + ) + .await + .unwrap(); + + let adapter = MemoryAdapter::default(); + let mut e = Enforcer::new(m, adapter).await.unwrap(); + + // Simulate external state (like a database connection or app state) + // This is the key feature: closures can capture external state + let allowed_actions: Arc> = + Arc::new(vec!["read".to_string(), "view".to_string()]); + + // Create a closure that captures the allowed_actions + let allowed_clone = allowed_actions.clone(); + e.add_function( + "checkPermission", + OperatorFunction::Arg1Closure(Arc::new(move |action: Dynamic| { + let action_str = action.to_string(); + allowed_clone.contains(&action_str).into() + })), + ); + + e.add_policy(vec![ + "alice".to_owned(), + "data1".to_owned(), + "read".to_owned(), + ]) + .await + .unwrap(); + + e.add_policy(vec![ + "alice".to_owned(), + "data1".to_owned(), + "write".to_owned(), + ]) + .await + .unwrap(); + + // "read" is in allowed_actions and has a policy, so this should be allowed + assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap()); + + // "write" is NOT in allowed_actions, so this should be denied + // even though there's a policy for it + assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap()); + + // "view" is in allowed_actions but there's no policy for it + assert_eq!(false, e.enforce(("alice", "data1", "view")).unwrap()); + + // "delete" is not in allowed_actions and has no policy + assert_eq!(false, e.enforce(("alice", "data1", "delete")).unwrap()); + } + + #[cfg(not(target_arch = "wasm32"))] + #[cfg_attr( + all(feature = "runtime-async-std", not(target_arch = "wasm32")), + async_std::test + )] + #[cfg_attr( + all(feature = "runtime-tokio", not(target_arch = "wasm32")), + tokio::test + )] + async fn test_custom_function_with_closure_two_args() { + use crate::prelude::*; + use std::sync::Arc; + + let m = DefaultModel::from_str( + r#" +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = r.sub == p.sub && customCheck(r.obj, p.obj) +"#, + ) + .await + .unwrap(); + + let adapter = MemoryAdapter::default(); + let mut e = Enforcer::new(m, adapter).await.unwrap(); + + // Simulate external configuration that affects matching logic + let prefix_map: Arc> = Arc::new({ + let mut m = HashMap::new(); + m.insert("data1".to_string(), "/api/v1/".to_string()); + m.insert("data2".to_string(), "/api/v2/".to_string()); + m + }); + + // Create a closure that captures external state + let prefix_clone = prefix_map.clone(); + e.add_function( + "customCheck", + OperatorFunction::Arg2Closure(Arc::new( + move |r_obj: Dynamic, p_obj: Dynamic| { + let r_obj_str = r_obj.to_string(); + let p_obj_str = p_obj.to_string(); + + // Check if r_obj starts with the prefix mapped from p_obj + if let Some(prefix) = prefix_clone.get(&p_obj_str) { + r_obj_str.starts_with(prefix).into() + } else { + (r_obj_str == p_obj_str).into() + } + }, + )), + ); + + e.add_policy(vec![ + "alice".to_owned(), + "data1".to_owned(), + "read".to_owned(), + ]) + .await + .unwrap(); + + // Request obj "/api/v1/users" should match policy obj "data1" + // because prefix_map["data1"] = "/api/v1/" + assert_eq!( + true, + e.enforce(("alice", "/api/v1/users", "read")).unwrap() + ); + + // Request obj "/api/v2/users" should NOT match policy obj "data1" + assert_eq!( + false, + e.enforce(("alice", "/api/v2/users", "read")).unwrap() + ); + } } diff --git a/src/model/function_map.rs b/src/model/function_map.rs index ee828e87..c57deb41 100644 --- a/src/model/function_map.rs +++ b/src/model/function_map.rs @@ -15,7 +15,7 @@ use rhai::Dynamic; static MAT_B: Lazy = Lazy::new(|| Regex::new(r":[^/]*").unwrap()); static MAT_P: Lazy = Lazy::new(|| Regex::new(r"\{[^/]*\}").unwrap()); -use std::{borrow::Cow, collections::HashMap}; +use std::{borrow::Cow, collections::HashMap, sync::Arc}; /// Represents a custom operator function that can be registered with Casbin. /// @@ -30,29 +30,37 @@ use std::{borrow::Cow, collections::HashMap}; /// /// This allows for flexible custom functions that can work with different types. /// +/// There are two variants for each argument count: +/// - `ArgN`: Uses a simple function pointer (for stateless functions) +/// - `ArgNClosure`: Uses an `Arc` (for closures that capture state) +/// /// # Example /// /// ```rust,ignore /// use casbin::{CoreApi, OperatorFunction}; /// use rhai::Dynamic; +/// use std::sync::Arc; /// -/// // Function that works with integers +/// // Function pointer (stateless) - uses Arg2 /// let int_fn = OperatorFunction::Arg2(|a: Dynamic, b: Dynamic| { /// let a_int = a.as_int().unwrap_or(0); /// let b_int = b.as_int().unwrap_or(0); /// (a_int > b_int).into() /// }); /// -/// // Function that works with strings -/// let str_fn = OperatorFunction::Arg2(|a: Dynamic, b: Dynamic| { -/// use casbin::model::function_map::dynamic_to_str; -/// let a_str = dynamic_to_str(&a); -/// let b_str = dynamic_to_str(&b); -/// a_str.contains(b_str.as_ref()).into() -/// }); +/// // Closure that captures state - uses Arg2Closure +/// let db_connection = Arc::new(some_database_connection); +/// let db_conn_clone = db_connection.clone(); +/// let closure_fn = OperatorFunction::Arg2Closure(Arc::new(move |a: Dynamic, b: Dynamic| { +/// // Access db_conn_clone here +/// let a_str = a.to_string(); +/// let b_str = b.to_string(); +/// (a_str == b_str).into() +/// })); /// ``` -#[derive(Clone, Copy)] +#[derive(Clone)] pub enum OperatorFunction { + // Function pointer variants (stateless) Arg0(fn() -> Dynamic), Arg1(fn(Dynamic) -> Dynamic), Arg2(fn(Dynamic, Dynamic) -> Dynamic), @@ -60,6 +68,39 @@ pub enum OperatorFunction { Arg4(fn(Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic), Arg5(fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic), Arg6(fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic), + // Closure variants (can capture state) + Arg0Closure(Arc Dynamic + Send + Sync>), + Arg1Closure(Arc Dynamic + Send + Sync>), + Arg2Closure(Arc Dynamic + Send + Sync>), + Arg3Closure( + Arc Dynamic + Send + Sync>, + ), + Arg4Closure( + Arc< + dyn Fn(Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic + Send + Sync, + >, + ), + Arg5Closure( + Arc< + dyn Fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic + + Send + + Sync, + >, + ), + Arg6Closure( + Arc< + dyn Fn( + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + Dynamic, + ) -> Dynamic + + Send + + Sync, + >, + ), } pub struct FunctionMap { From a1b1d1ac5b1361cd08f5b18d025372e701ab21ba Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 2 Dec 2025 04:46:44 +0000 Subject: [PATCH 3/3] Fix missing HashMap import in test Co-authored-by: hsluoyz <3787410+hsluoyz@users.noreply.github.com> --- src/enforcer.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/enforcer.rs b/src/enforcer.rs index d2e67688..e52c45f1 100644 --- a/src/enforcer.rs +++ b/src/enforcer.rs @@ -2103,6 +2103,7 @@ m = r.sub == p.sub && r.obj == p.obj && r.act == p.act && checkPermission(r.act) )] async fn test_custom_function_with_closure_two_args() { use crate::prelude::*; + use std::collections::HashMap; use std::sync::Arc; let m = DefaultModel::from_str(