Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion CUSTOM_FUNCTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ m = greaterThan(r.age, 18) && stringContains(r.path, p.path)

## OperatorFunction Variants

The `OperatorFunction` enum supports functions with 0 to 6 arguments:
The `OperatorFunction` enum supports functions with 0 to 6 arguments. There are two types of variants:

### Function Pointer Variants (Stateless)

These use simple function pointers and cannot capture external state:

- `Arg0`: `fn() -> Dynamic`
- `Arg1`: `fn(Dynamic) -> Dynamic`
Expand All @@ -147,6 +151,98 @@ The `OperatorFunction` enum supports functions with 0 to 6 arguments:
- `Arg5`: `fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic`
- `Arg6`: `fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic`

### Closure Variants (Can Capture State)

These use `Arc<dyn Fn>` and can capture external state like database connections:

- `Arg0Closure`: `Arc<dyn Fn() -> Dynamic + Send + Sync>`
- `Arg1Closure`: `Arc<dyn Fn(Dynamic) -> Dynamic + Send + Sync>`
- `Arg2Closure`: `Arc<dyn Fn(Dynamic, Dynamic) -> Dynamic + Send + Sync>`
- `Arg3Closure`: `Arc<dyn Fn(Dynamic, Dynamic, Dynamic) -> Dynamic + Send + Sync>`
- `Arg4Closure`: `Arc<dyn Fn(Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic + Send + Sync>`
- `Arg5Closure`: `Arc<dyn Fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic + Send + Sync>`
- `Arg6Closure`: `Arc<dyn Fn(Dynamic, Dynamic, Dynamic, Dynamic, Dynamic, Dynamic) -> Dynamic + Send + Sync>`

## Capturing External State with Closures

One of the key features of the closure variants is the ability to capture external state. This is useful when you need to access application state (like database connections) from within your custom functions.

### Example: Capturing Database Connection in Axum

```rust
use casbin::{CoreApi, Enforcer, OperatorFunction};
use rhai::Dynamic;
use std::sync::Arc;

#[derive(Clone)]
struct AppState {
db_connection: Arc<DatabaseConnection>,
casbin_enforcer: Arc<Mutex<Enforcer>>,
}

impl AppState {
async fn setup_enforcer_with_db_check(&mut self) {
// Clone the Arc to capture it in the closure
let db_conn = self.db_connection.clone();

// Create a closure that captures the database connection
let check_fn = Arc::new(move |product_id: Dynamic| {
let product_id_int = product_id.as_int().unwrap_or(0);

// Access the database connection from the captured state
// Note: In real code, you'd need to handle async properly
let has_storages = db_conn.check_product_has_storages(product_id_int);

has_storages.into()
});

// Register the closure-based function
if let Ok(mut enforcer) = self.casbin_enforcer.lock() {
enforcer.add_function(
"matchProductHasStorages",
OperatorFunction::Arg1Closure(check_fn),
);
}
}
}
```

### Example: Multi-Argument Closure with Shared State

```rust
use casbin::{CoreApi, Enforcer, OperatorFunction};
use rhai::Dynamic;
use std::sync::Arc;
use std::collections::HashMap;

// Simulate external configuration
let prefix_map: Arc<HashMap<String, String>> = Arc::new({
let mut m = HashMap::new();
m.insert("data1".to_string(), "/api/v1/".to_string());
m.insert("data2".to_string(), "/api/v2/".to_string());
m
});

// Clone for the closure
let prefix_clone = prefix_map.clone();

// Create a two-argument closure that uses the shared state
e.add_function(
"customPathCheck",
OperatorFunction::Arg2Closure(Arc::new(move |request_path: Dynamic, policy_resource: Dynamic| {
let req_path = request_path.to_string();
let policy_res = policy_resource.to_string();

// Use the captured prefix_map to determine if paths match
if let Some(prefix) = prefix_clone.get(&policy_res) {
req_path.starts_with(prefix).into()
} else {
(req_path == policy_res).into()
}
})),
);
```

## Working with Dynamic Types

Rhai's `Dynamic` type provides several methods to extract values:
Expand Down
219 changes: 216 additions & 3 deletions src/enforcer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,56 @@ impl Enforcer {
OperatorFunction::Arg6(func) => {
engine.register_fn(key, func);
}
// Closure variants
OperatorFunction::Arg0Closure(func) => {
engine.register_fn(key, move || func());
}
OperatorFunction::Arg1Closure(func) => {
engine.register_fn(key, move |a: Dynamic| func(a));
}
OperatorFunction::Arg2Closure(func) => {
engine
.register_fn(key, move |a: Dynamic, b: Dynamic| func(a, b));
}
OperatorFunction::Arg3Closure(func) => {
engine.register_fn(
key,
move |a: Dynamic, b: Dynamic, c: Dynamic| func(a, b, c),
);
}
OperatorFunction::Arg4Closure(func) => {
engine.register_fn(
key,
move |a: Dynamic, b: Dynamic, c: Dynamic, d: Dynamic| {
func(a, b, c, d)
},
);
}
OperatorFunction::Arg5Closure(func) => {
engine.register_fn(
key,
move |a: Dynamic,
b: Dynamic,
c: Dynamic,
d: Dynamic,
e: Dynamic| {
func(a, b, c, d, e)
},
);
}
OperatorFunction::Arg6Closure(func) => {
engine.register_fn(
key,
move |a: Dynamic,
b: Dynamic,
c: Dynamic,
d: Dynamic,
e: Dynamic,
g: Dynamic| {
func(a, b, c, d, e, g)
},
);
}
}
}

Expand Down Expand Up @@ -434,8 +484,8 @@ impl CoreApi for Enforcer {

engine.register_global_module(CASBIN_PACKAGE.as_shared_module());

for (key, &func) in fm.get_functions() {
Self::register_function(&mut engine, key, func);
for (key, func) in fm.get_functions() {
Self::register_function(&mut engine, key, func.clone());
}

let mut e = Self {
Expand Down Expand Up @@ -488,7 +538,7 @@ impl CoreApi for Enforcer {

#[inline]
fn add_function(&mut self, fname: &str, f: OperatorFunction) {
self.fm.add_function(fname, f);
self.fm.add_function(fname, f.clone());
Self::register_function(&mut self.engine, fname, f);
}

Expand Down Expand Up @@ -1962,4 +2012,167 @@ m = r.sub == p.sub && r.obj == p.obj && r.act == p.act
assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(
all(feature = "runtime-async-std", not(target_arch = "wasm32")),
async_std::test
)]
#[cfg_attr(
all(feature = "runtime-tokio", not(target_arch = "wasm32")),
tokio::test
)]
async fn test_custom_function_with_closure_capturing_state() {
use crate::prelude::*;
use std::sync::Arc;

let m = DefaultModel::from_str(
r#"
[request_definition]
r = sub, obj, act

[policy_definition]
p = sub, obj, act

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = r.sub == p.sub && r.obj == p.obj && r.act == p.act && checkPermission(r.act)
"#,
)
.await
.unwrap();

let adapter = MemoryAdapter::default();
let mut e = Enforcer::new(m, adapter).await.unwrap();

// Simulate external state (like a database connection or app state)
// This is the key feature: closures can capture external state
let allowed_actions: Arc<Vec<String>> =
Arc::new(vec!["read".to_string(), "view".to_string()]);

// Create a closure that captures the allowed_actions
let allowed_clone = allowed_actions.clone();
e.add_function(
"checkPermission",
OperatorFunction::Arg1Closure(Arc::new(move |action: Dynamic| {
let action_str = action.to_string();
allowed_clone.contains(&action_str).into()
})),
);

e.add_policy(vec![
"alice".to_owned(),
"data1".to_owned(),
"read".to_owned(),
])
.await
.unwrap();

e.add_policy(vec![
"alice".to_owned(),
"data1".to_owned(),
"write".to_owned(),
])
.await
.unwrap();

// "read" is in allowed_actions and has a policy, so this should be allowed
assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());

// "write" is NOT in allowed_actions, so this should be denied
// even though there's a policy for it
assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());

// "view" is in allowed_actions but there's no policy for it
assert_eq!(false, e.enforce(("alice", "data1", "view")).unwrap());

// "delete" is not in allowed_actions and has no policy
assert_eq!(false, e.enforce(("alice", "data1", "delete")).unwrap());
}

#[cfg(not(target_arch = "wasm32"))]
#[cfg_attr(
all(feature = "runtime-async-std", not(target_arch = "wasm32")),
async_std::test
)]
#[cfg_attr(
all(feature = "runtime-tokio", not(target_arch = "wasm32")),
tokio::test
)]
async fn test_custom_function_with_closure_two_args() {
use crate::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;

let m = DefaultModel::from_str(
r#"
[request_definition]
r = sub, obj, act

[policy_definition]
p = sub, obj, act

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = r.sub == p.sub && customCheck(r.obj, p.obj)
"#,
)
.await
.unwrap();

let adapter = MemoryAdapter::default();
let mut e = Enforcer::new(m, adapter).await.unwrap();

// Simulate external configuration that affects matching logic
let prefix_map: Arc<HashMap<String, String>> = Arc::new({
let mut m = HashMap::new();
m.insert("data1".to_string(), "/api/v1/".to_string());
m.insert("data2".to_string(), "/api/v2/".to_string());
m
});

// Create a closure that captures external state
let prefix_clone = prefix_map.clone();
e.add_function(
"customCheck",
OperatorFunction::Arg2Closure(Arc::new(
move |r_obj: Dynamic, p_obj: Dynamic| {
let r_obj_str = r_obj.to_string();
let p_obj_str = p_obj.to_string();

// Check if r_obj starts with the prefix mapped from p_obj
if let Some(prefix) = prefix_clone.get(&p_obj_str) {
r_obj_str.starts_with(prefix).into()
} else {
(r_obj_str == p_obj_str).into()
}
},
)),
);

e.add_policy(vec![
"alice".to_owned(),
"data1".to_owned(),
"read".to_owned(),
])
.await
.unwrap();

// Request obj "/api/v1/users" should match policy obj "data1"
// because prefix_map["data1"] = "/api/v1/"
assert_eq!(
true,
e.enforce(("alice", "/api/v1/users", "read")).unwrap()
);

// Request obj "/api/v2/users" should NOT match policy obj "data1"
assert_eq!(
false,
e.enforce(("alice", "/api/v2/users", "read")).unwrap()
);
}
}
Loading
Loading