11use std:: collections:: HashSet ;
22use std:: future:: ready;
3- use std:: ops :: Deref ;
3+ use std:: rc :: Rc ;
44use std:: time:: { Duration , Instant } ;
55use std:: { future:: Future , pin:: Pin , str:: FromStr , sync:: Arc } ;
66
@@ -32,7 +32,7 @@ use openidconnect::{
3232 StandardTokenResponse ,
3333} ;
3434use serde:: { Deserialize , Serialize } ;
35- use std:: sync:: Mutex ;
35+ use std:: sync:: { Mutex , MutexGuard } ;
3636
3737use super :: http_client:: make_http_client;
3838
@@ -188,7 +188,7 @@ impl OidcState {
188188 }
189189
190190 /// Gets a reference to the oidc client, potentially generating a new one if needed
191- pub async fn get_client ( & self ) -> impl Deref < Target = ClientWithTime > + ' _ {
191+ pub async fn get_client ( & self ) -> MutexGuard < ' _ , ClientWithTime > {
192192 {
193193 let client_lock = self . client . lock ( ) . expect ( "oidc client" ) ;
194194 if client_lock. last_update . elapsed ( ) < OIDC_CLIENT_REFRESH_INTERVAL {
@@ -290,7 +290,7 @@ where
290290
291291#[ derive( Clone ) ]
292292pub struct OidcService < S > {
293- service : S ,
293+ service : Rc < S > ,
294294 oidc_state : Arc < OidcState > ,
295295}
296296
@@ -301,56 +301,72 @@ where
301301{
302302 pub fn new ( service : S , oidc_state : Arc < OidcState > ) -> Self {
303303 Self {
304- service,
304+ service : Rc :: new ( service ) ,
305305 oidc_state,
306306 }
307307 }
308+ }
308309
309- fn handle_unauthenticated_request (
310- & self ,
311- request : ServiceRequest ,
312- ) -> LocalBoxFuture < Result < ServiceResponse < BoxBody > , Error > > {
313- log:: debug!( "Handling unauthenticated request to {}" , request. path( ) ) ;
314- if request. path ( ) == SQLPAGE_REDIRECT_URI {
315- log:: debug!( "The request is the OIDC callback" ) ;
316- return self . handle_oidc_callback ( request) ;
317- }
310+ enum MiddlewareResponse {
311+ Forward ( ServiceRequest ) ,
312+ Respond ( ServiceResponse ) ,
313+ }
318314
319- if self . oidc_state . config . is_public_path ( request. path ( ) ) {
320- log:: debug!(
321- "The request path {} is not in a public path, skipping OIDC authentication" ,
322- request. path( )
323- ) ;
324- return Box :: pin ( self . service . call ( request) ) ;
315+ async fn handle_request (
316+ oidc_state : & OidcState ,
317+ request : ServiceRequest ,
318+ ) -> actix_web:: Result < MiddlewareResponse > {
319+ log:: trace!( "Started OIDC middleware request handling" ) ;
320+ let response = match get_authenticated_user_info ( oidc_state, & request) {
321+ Ok ( Some ( claims) ) => {
322+ if request. path ( ) != SQLPAGE_REDIRECT_URI {
323+ log:: trace!( "Storing authenticated user info in request extensions: {claims:?}" ) ;
324+ request. extensions_mut ( ) . insert ( claims) ;
325+ return Ok ( MiddlewareResponse :: Forward ( request) ) ;
326+ }
327+ handle_authenticated_oidc_callback ( request) . await
325328 }
329+ Ok ( None ) => {
330+ log:: trace!( "No authenticated user found" ) ;
331+ handle_unauthenticated_request ( oidc_state, request) . await
332+ }
333+ Err ( e) => {
334+ log:: debug!( "An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}" ) ;
335+ handle_unauthenticated_request ( oidc_state, request) . await
336+ }
337+ } ;
338+ response. map ( MiddlewareResponse :: Respond )
339+ }
326340
327- log:: debug!( "Redirecting to OIDC provider" ) ;
328-
329- let oidc_state = Arc :: clone ( & self . oidc_state ) ;
330- Box :: pin ( async move {
331- let response = build_auth_provider_redirect_response ( & oidc_state, & request) ;
332- Ok ( request. into_response ( response) )
333- } )
341+ async fn handle_unauthenticated_request (
342+ oidc_state : & OidcState ,
343+ request : ServiceRequest ,
344+ ) -> Result < ServiceResponse < BoxBody > , Error > {
345+ log:: debug!( "Handling unauthenticated request to {}" , request. path( ) ) ;
346+ if request. path ( ) == SQLPAGE_REDIRECT_URI {
347+ log:: debug!( "The request is the OIDC callback" ) ;
348+ return handle_oidc_callback ( oidc_state, request) . await ;
334349 }
335350
336- fn handle_oidc_callback (
337- & self ,
338- request : ServiceRequest ,
339- ) -> LocalBoxFuture < Result < ServiceResponse < BoxBody > , Error > > {
340- let oidc_state = Arc :: clone ( & self . oidc_state ) ;
351+ log:: debug!( "Redirecting to OIDC provider" ) ;
341352
342- Box :: pin ( async move {
343- let query_string = request. query_string ( ) ;
344- match process_oidc_callback ( & oidc_state, query_string, & request) . await {
345- Ok ( response) => Ok ( request. into_response ( response) ) ,
346- Err ( e) => {
347- log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
348- oidc_state. refresh ( ) . await ;
349- let resp = build_auth_provider_redirect_response ( & oidc_state, & request) ;
350- Ok ( request. into_response ( resp) )
351- }
352- }
353- } )
353+ let response = build_auth_provider_redirect_response ( oidc_state, & request) ;
354+ Ok ( request. into_response ( response) )
355+ }
356+
357+ async fn handle_oidc_callback (
358+ oidc_state : & OidcState ,
359+ request : ServiceRequest ,
360+ ) -> Result < ServiceResponse < BoxBody > , Error > {
361+ let query_string = request. query_string ( ) ;
362+ match process_oidc_callback ( oidc_state, query_string, & request) . await {
363+ Ok ( response) => Ok ( request. into_response ( response) ) ,
364+ Err ( e) => {
365+ log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
366+ oidc_state. refresh ( ) . await ;
367+ let resp = build_auth_provider_redirect_response ( oidc_state, & request) ;
368+ Ok ( request. into_response ( resp) )
369+ }
354370 }
355371}
356372
@@ -369,7 +385,7 @@ fn handle_authenticated_oidc_callback(
369385
370386impl < S > Service < ServiceRequest > for OidcService < S >
371387where
372- S : Service < ServiceRequest , Response = ServiceResponse < BoxBody > , Error = Error > ,
388+ S : Service < ServiceRequest , Response = ServiceResponse < BoxBody > , Error = Error > + ' static ,
373389 S :: Future : ' static ,
374390{
375391 type Response = ServiceResponse < BoxBody > ;
@@ -379,32 +395,18 @@ where
379395 forward_ready ! ( service) ;
380396
381397 fn call ( & self , request : ServiceRequest ) -> Self :: Future {
382- log:: trace!( "Started OIDC middleware request handling" ) ;
383-
384- match get_authenticated_user_info ( & self . oidc_state , & request) {
385- Ok ( Some ( claims) ) => {
386- if request. path ( ) == SQLPAGE_REDIRECT_URI {
387- return handle_authenticated_oidc_callback ( request) ;
388- }
389- log:: trace!( "Storing authenticated user info in request extensions: {claims:?}" ) ;
390- request. extensions_mut ( ) . insert ( claims) ;
391- }
392- Ok ( None ) => {
393- log:: trace!( "No authenticated user found" ) ;
394- return self . handle_unauthenticated_request ( request) ;
395- }
396- Err ( e) => {
397- log:: debug!(
398- "{:?}" ,
399- e. context(
400- "An auth cookie is present but could not be verified. \
401- Redirecting to OIDC provider to re-authenticate."
402- )
403- ) ;
404- return self . handle_unauthenticated_request ( request) ;
405- }
398+ if self . oidc_state . config . is_public_path ( request. path ( ) ) {
399+ return Box :: pin ( self . service . call ( request) ) ;
406400 }
407- Box :: pin ( self . service . call ( request) )
401+ let srv = Rc :: clone ( & self . service ) ;
402+ let oidc_state = Arc :: clone ( & self . oidc_state ) ;
403+ Box :: pin ( async move {
404+ match handle_request ( & oidc_state, request) . await {
405+ Ok ( MiddlewareResponse :: Respond ( response) ) => Ok ( response) ,
406+ Ok ( MiddlewareResponse :: Forward ( request) ) => srv. call ( request) . await ,
407+ Err ( err) => Err ( err) ,
408+ }
409+ } )
408410 }
409411}
410412
@@ -512,7 +514,7 @@ fn build_redirect_response(target_url: String) -> HttpResponse {
512514
513515/// Returns the claims from the ID token in the `SQLPage` auth cookie.
514516fn get_authenticated_user_info (
515- oidc_state : & Arc < OidcState > ,
517+ oidc_state : & OidcState ,
516518 request : & ServiceRequest ,
517519) -> anyhow:: Result < Option < OidcClaims > > {
518520 let Some ( cookie) = request. cookie ( SQLPAGE_AUTH_COOKIE_NAME ) else {
0 commit comments