|
1 | 1 | use pyo3::prelude::*; |
2 | 2 | use std::sync::Arc; |
| 3 | +use std::time::Duration; |
3 | 4 |
|
4 | 5 | use crate::runtime::TOKIO_RT; |
5 | 6 |
|
| 7 | +/// Python-visible options for Entra ID (Azure AD) authentication. |
| 8 | +/// |
| 9 | +/// All fields are optional; omitting a field uses the duroxide-pg default. |
| 10 | +#[pyclass] |
| 11 | +#[derive(Clone, Default)] |
| 12 | +pub struct PyPostgresEntraOptions { |
| 13 | + pub audience: Option<String>, |
| 14 | + pub max_connections: Option<u32>, |
| 15 | + pub acquire_timeout_ms: Option<u64>, |
| 16 | + pub refresh_interval_ms: Option<u64>, |
| 17 | +} |
| 18 | + |
| 19 | +#[pymethods] |
| 20 | +impl PyPostgresEntraOptions { |
| 21 | + #[new] |
| 22 | + #[pyo3(signature = (*, audience=None, max_connections=None, acquire_timeout_ms=None, refresh_interval_ms=None))] |
| 23 | + fn new( |
| 24 | + audience: Option<String>, |
| 25 | + max_connections: Option<u32>, |
| 26 | + acquire_timeout_ms: Option<u64>, |
| 27 | + refresh_interval_ms: Option<u64>, |
| 28 | + ) -> Self { |
| 29 | + Self { |
| 30 | + audience, |
| 31 | + max_connections, |
| 32 | + acquire_timeout_ms, |
| 33 | + refresh_interval_ms, |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +impl PyPostgresEntraOptions { |
| 39 | + fn into_entra_auth_options(self) -> duroxide_pg::EntraAuthOptions { |
| 40 | + let mut opts = duroxide_pg::EntraAuthOptions::new(); |
| 41 | + if let Some(aud) = self.audience { |
| 42 | + opts = opts.audience(aud); |
| 43 | + } |
| 44 | + if let Some(mc) = self.max_connections { |
| 45 | + opts = opts.max_connections(mc); |
| 46 | + } |
| 47 | + if let Some(ms) = self.acquire_timeout_ms { |
| 48 | + opts = opts.acquire_timeout(Duration::from_millis(ms)); |
| 49 | + } |
| 50 | + if let Some(ms) = self.refresh_interval_ms { |
| 51 | + opts = opts.refresh_interval(Duration::from_millis(ms)); |
| 52 | + } |
| 53 | + opts |
| 54 | + } |
| 55 | +} |
| 56 | + |
6 | 57 | /// Wraps duroxide-pg's PostgresProvider for use from Python. |
7 | 58 | #[pyclass] |
8 | 59 | pub struct PyPostgresProvider { |
@@ -46,4 +97,72 @@ impl PyPostgresProvider { |
46 | 97 | inner: Arc::new(provider), |
47 | 98 | }) |
48 | 99 | } |
| 100 | + |
| 101 | + /// Connect to Azure Database for PostgreSQL using Microsoft Entra ID |
| 102 | + /// (Azure AD) token authentication. The runtime fetches and refreshes |
| 103 | + /// the token automatically via the DefaultAzureCredential chain. |
| 104 | + /// |
| 105 | + /// `user` must be the Entra principal name mapped to a PostgreSQL role |
| 106 | + /// on the server. Pass `None` for `options` to use defaults. |
| 107 | + #[staticmethod] |
| 108 | + #[pyo3(signature = (host, port, database, user, options=None))] |
| 109 | + fn connect_with_entra( |
| 110 | + host: String, |
| 111 | + port: u16, |
| 112 | + database: String, |
| 113 | + user: String, |
| 114 | + options: Option<PyPostgresEntraOptions>, |
| 115 | + ) -> PyResult<Self> { |
| 116 | + let entra_opts = options.unwrap_or_default().into_entra_auth_options(); |
| 117 | + let provider = TOKIO_RT |
| 118 | + .block_on(async { |
| 119 | + duroxide_pg::PostgresProvider::new_with_entra( |
| 120 | + &host, port, &database, &user, entra_opts, |
| 121 | + ) |
| 122 | + .await |
| 123 | + }) |
| 124 | + .map_err(|e| { |
| 125 | + pyo3::exceptions::PyRuntimeError::new_err(format!( |
| 126 | + "Failed to connect to PostgreSQL with Entra auth: {e}" |
| 127 | + )) |
| 128 | + })?; |
| 129 | + Ok(Self { |
| 130 | + inner: Arc::new(provider), |
| 131 | + }) |
| 132 | + } |
| 133 | + |
| 134 | + /// Same as `connect_with_entra` but uses a custom schema for tenant |
| 135 | + /// isolation. The schema will be created if it does not exist. |
| 136 | + #[staticmethod] |
| 137 | + #[pyo3(signature = (host, port, database, user, schema, options=None))] |
| 138 | + fn connect_with_schema_and_entra( |
| 139 | + host: String, |
| 140 | + port: u16, |
| 141 | + database: String, |
| 142 | + user: String, |
| 143 | + schema: String, |
| 144 | + options: Option<PyPostgresEntraOptions>, |
| 145 | + ) -> PyResult<Self> { |
| 146 | + let entra_opts = options.unwrap_or_default().into_entra_auth_options(); |
| 147 | + let provider = TOKIO_RT |
| 148 | + .block_on(async { |
| 149 | + duroxide_pg::PostgresProvider::new_with_schema_and_entra( |
| 150 | + &host, |
| 151 | + port, |
| 152 | + &database, |
| 153 | + &user, |
| 154 | + Some(&schema), |
| 155 | + entra_opts, |
| 156 | + ) |
| 157 | + .await |
| 158 | + }) |
| 159 | + .map_err(|e| { |
| 160 | + pyo3::exceptions::PyRuntimeError::new_err(format!( |
| 161 | + "Failed to connect to PostgreSQL with Entra auth: {e}" |
| 162 | + )) |
| 163 | + })?; |
| 164 | + Ok(Self { |
| 165 | + inner: Arc::new(provider), |
| 166 | + }) |
| 167 | + } |
49 | 168 | } |
0 commit comments