Skip to content

Commit 5cf33e7

Browse files
committed
refactor: restructure OIDC middleware for better modularity and
ownership - Extract request handling logic into standalone async functions - Replace service ownership with Rc<S> for better memory management - Simplify get_client() return type to use MutexGuard directly - Add MiddlewareResponse enum to clarify response handling flow - Move public path checking to beginning of request pipeline - Reorganize code to separate concerns and improve readability
1 parent 5bd3ac8 commit 5cf33e7

File tree

1 file changed

+73
-71
lines changed

1 file changed

+73
-71
lines changed

src/webserver/oidc.rs

Lines changed: 73 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::HashSet;
22
use std::future::ready;
3-
use std::ops::Deref;
3+
use std::rc::Rc;
44
use std::time::{Duration, Instant};
55
use std::{future::Future, pin::Pin, str::FromStr, sync::Arc};
66

@@ -32,7 +32,7 @@ use openidconnect::{
3232
StandardTokenResponse,
3333
};
3434
use serde::{Deserialize, Serialize};
35-
use std::sync::Mutex;
35+
use std::sync::{Mutex, MutexGuard};
3636

3737
use super::http_client::make_http_client;
3838

@@ -188,7 +188,7 @@ impl OidcState {
188188
}
189189

190190
/// Gets a reference to the oidc client, potentially generating a new one if needed
191-
pub async fn get_client(&self) -> impl Deref<Target = ClientWithTime> + '_ {
191+
pub async fn get_client(&self) -> MutexGuard<'_, ClientWithTime> {
192192
{
193193
let client_lock = self.client.lock().expect("oidc client");
194194
if client_lock.last_update.elapsed() < OIDC_CLIENT_REFRESH_INTERVAL {
@@ -290,7 +290,7 @@ where
290290

291291
#[derive(Clone)]
292292
pub struct OidcService<S> {
293-
service: S,
293+
service: Rc<S>,
294294
oidc_state: Arc<OidcState>,
295295
}
296296

@@ -301,56 +301,72 @@ where
301301
{
302302
pub fn new(service: S, oidc_state: Arc<OidcState>) -> Self {
303303
Self {
304-
service,
304+
service: Rc::new(service),
305305
oidc_state,
306306
}
307307
}
308+
}
308309

309-
fn handle_unauthenticated_request(
310-
&self,
311-
request: ServiceRequest,
312-
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
313-
log::debug!("Handling unauthenticated request to {}", request.path());
314-
if request.path() == SQLPAGE_REDIRECT_URI {
315-
log::debug!("The request is the OIDC callback");
316-
return self.handle_oidc_callback(request);
317-
}
310+
enum MiddlewareResponse {
311+
Forward(ServiceRequest),
312+
Respond(ServiceResponse),
313+
}
318314

319-
if self.oidc_state.config.is_public_path(request.path()) {
320-
log::debug!(
321-
"The request path {} is not in a public path, skipping OIDC authentication",
322-
request.path()
323-
);
324-
return Box::pin(self.service.call(request));
315+
async fn handle_request(
316+
oidc_state: &OidcState,
317+
request: ServiceRequest,
318+
) -> actix_web::Result<MiddlewareResponse> {
319+
log::trace!("Started OIDC middleware request handling");
320+
let response = match get_authenticated_user_info(oidc_state, &request) {
321+
Ok(Some(claims)) => {
322+
if request.path() != SQLPAGE_REDIRECT_URI {
323+
log::trace!("Storing authenticated user info in request extensions: {claims:?}");
324+
request.extensions_mut().insert(claims);
325+
return Ok(MiddlewareResponse::Forward(request));
326+
}
327+
handle_authenticated_oidc_callback(request).await
325328
}
329+
Ok(None) => {
330+
log::trace!("No authenticated user found");
331+
handle_unauthenticated_request(oidc_state, request).await
332+
}
333+
Err(e) => {
334+
log::debug!("An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}");
335+
handle_unauthenticated_request(oidc_state, request).await
336+
}
337+
};
338+
response.map(MiddlewareResponse::Respond)
339+
}
326340

327-
log::debug!("Redirecting to OIDC provider");
328-
329-
let oidc_state = Arc::clone(&self.oidc_state);
330-
Box::pin(async move {
331-
let response = build_auth_provider_redirect_response(&oidc_state, &request);
332-
Ok(request.into_response(response))
333-
})
341+
async fn handle_unauthenticated_request(
342+
oidc_state: &OidcState,
343+
request: ServiceRequest,
344+
) -> Result<ServiceResponse<BoxBody>, Error> {
345+
log::debug!("Handling unauthenticated request to {}", request.path());
346+
if request.path() == SQLPAGE_REDIRECT_URI {
347+
log::debug!("The request is the OIDC callback");
348+
return handle_oidc_callback(oidc_state, request).await;
334349
}
335350

336-
fn handle_oidc_callback(
337-
&self,
338-
request: ServiceRequest,
339-
) -> LocalBoxFuture<Result<ServiceResponse<BoxBody>, Error>> {
340-
let oidc_state = Arc::clone(&self.oidc_state);
351+
log::debug!("Redirecting to OIDC provider");
341352

342-
Box::pin(async move {
343-
let query_string = request.query_string();
344-
match process_oidc_callback(&oidc_state, query_string, &request).await {
345-
Ok(response) => Ok(request.into_response(response)),
346-
Err(e) => {
347-
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
348-
oidc_state.refresh().await;
349-
let resp = build_auth_provider_redirect_response(&oidc_state, &request);
350-
Ok(request.into_response(resp))
351-
}
352-
}
353-
})
353+
let response = build_auth_provider_redirect_response(oidc_state, &request);
354+
Ok(request.into_response(response))
355+
}
356+
357+
async fn handle_oidc_callback(
358+
oidc_state: &OidcState,
359+
request: ServiceRequest,
360+
) -> Result<ServiceResponse<BoxBody>, Error> {
361+
let query_string = request.query_string();
362+
match process_oidc_callback(oidc_state, query_string, &request).await {
363+
Ok(response) => Ok(request.into_response(response)),
364+
Err(e) => {
365+
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
366+
oidc_state.refresh().await;
367+
let resp = build_auth_provider_redirect_response(oidc_state, &request);
368+
Ok(request.into_response(resp))
369+
}
354370
}
355371
}
356372

@@ -369,7 +385,7 @@ fn handle_authenticated_oidc_callback(
369385

370386
impl<S> Service<ServiceRequest> for OidcService<S>
371387
where
372-
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
388+
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + 'static,
373389
S::Future: 'static,
374390
{
375391
type Response = ServiceResponse<BoxBody>;
@@ -379,32 +395,18 @@ where
379395
forward_ready!(service);
380396

381397
fn call(&self, request: ServiceRequest) -> Self::Future {
382-
log::trace!("Started OIDC middleware request handling");
383-
384-
match get_authenticated_user_info(&self.oidc_state, &request) {
385-
Ok(Some(claims)) => {
386-
if request.path() == SQLPAGE_REDIRECT_URI {
387-
return handle_authenticated_oidc_callback(request);
388-
}
389-
log::trace!("Storing authenticated user info in request extensions: {claims:?}");
390-
request.extensions_mut().insert(claims);
391-
}
392-
Ok(None) => {
393-
log::trace!("No authenticated user found");
394-
return self.handle_unauthenticated_request(request);
395-
}
396-
Err(e) => {
397-
log::debug!(
398-
"{:?}",
399-
e.context(
400-
"An auth cookie is present but could not be verified. \
401-
Redirecting to OIDC provider to re-authenticate."
402-
)
403-
);
404-
return self.handle_unauthenticated_request(request);
405-
}
398+
if self.oidc_state.config.is_public_path(request.path()) {
399+
return Box::pin(self.service.call(request));
406400
}
407-
Box::pin(self.service.call(request))
401+
let srv = Rc::clone(&self.service);
402+
let oidc_state = Arc::clone(&self.oidc_state);
403+
Box::pin(async move {
404+
match handle_request(&oidc_state, request).await {
405+
Ok(MiddlewareResponse::Respond(response)) => Ok(response),
406+
Ok(MiddlewareResponse::Forward(request)) => srv.call(request).await,
407+
Err(err) => Err(err),
408+
}
409+
})
408410
}
409411
}
410412

@@ -512,7 +514,7 @@ fn build_redirect_response(target_url: String) -> HttpResponse {
512514

513515
/// Returns the claims from the ID token in the `SQLPage` auth cookie.
514516
fn get_authenticated_user_info(
515-
oidc_state: &Arc<OidcState>,
517+
oidc_state: &OidcState,
516518
request: &ServiceRequest,
517519
) -> anyhow::Result<Option<OidcClaims>> {
518520
let Some(cookie) = request.cookie(SQLPAGE_AUTH_COOKIE_NAME) else {

0 commit comments

Comments
 (0)