@@ -32,7 +32,7 @@ use openidconnect::{
3232 StandardTokenResponse ,
3333} ;
3434use serde:: { Deserialize , Serialize } ;
35- use std :: sync:: { Mutex , MutexGuard } ;
35+ use tokio :: sync:: { RwLock , RwLockReadGuard } ;
3636
3737use super :: http_client:: make_http_client;
3838
@@ -151,7 +151,7 @@ pub struct ClientWithTime {
151151pub struct OidcState {
152152 pub config : OidcConfig ,
153153 app_config : AppConfig ,
154- client : Mutex < ClientWithTime > ,
154+ client : RwLock < ClientWithTime > ,
155155}
156156
157157impl OidcState {
@@ -162,7 +162,7 @@ impl OidcState {
162162 Ok ( Self {
163163 config : oidc_cfg,
164164 app_config,
165- client : Mutex :: new ( ClientWithTime {
165+ client : RwLock :: new ( ClientWithTime {
166166 client,
167167 last_update : Instant :: now ( ) ,
168168 } ) ,
@@ -174,9 +174,10 @@ impl OidcState {
174174 log:: error!( "Failed to create HTTP client" ) ;
175175 return ;
176176 } ;
177+ let mut write_guard = self . client . write ( ) . await ;
177178 match build_oidc_client ( & self . config , & http_client) . await {
178179 Ok ( client) => {
179- * self . client . lock ( ) . expect ( "oidc client" ) = ClientWithTime {
180+ * write_guard = ClientWithTime {
180181 client,
181182 last_update : Instant :: now ( ) ,
182183 } ;
@@ -188,26 +189,26 @@ impl OidcState {
188189 }
189190
190191 /// Gets a reference to the oidc client, potentially generating a new one if needed
191- pub async fn get_client ( & self ) -> MutexGuard < ' _ , ClientWithTime > {
192+ pub async fn get_client ( & self ) -> RwLockReadGuard < ' _ , ClientWithTime > {
192193 {
193- let client_lock = self . client . lock ( ) . expect ( "oidc client" ) ;
194+ let client_lock = self . client . read ( ) . await ;
194195 if client_lock. last_update . elapsed ( ) < OIDC_CLIENT_REFRESH_INTERVAL {
195196 return client_lock;
196197 }
197198 }
198199 log:: debug!( "OIDC client is older than {OIDC_CLIENT_REFRESH_INTERVAL:?}, refreshing..." ) ;
199200 self . refresh ( ) . await ;
200- self . client . lock ( ) . expect ( "oidc client" )
201+ self . client . read ( ) . await
201202 }
202203
203204 /// Validate and decode the claims of an OIDC token, without refreshing the client.
204- fn get_token_claims (
205+ async fn get_token_claims (
205206 & self ,
206207 id_token : & OidcToken ,
207208 state : Option < & OidcLoginState > ,
208209 ) -> anyhow:: Result < OidcClaims > {
209210 // Do not refresh the client on every check
210- let client = & self . client . lock ( ) . expect ( "oidc client" ) . client ;
211+ let client = & self . client . read ( ) . await . client ;
211212 let verifier = self . config . create_id_token_verifier ( client) ;
212213 let nonce_verifier = |nonce : Option < & Nonce > | check_nonce ( nonce, state) ;
213214 let claims: OidcClaims = id_token
@@ -317,7 +318,7 @@ async fn handle_request(
317318 request : ServiceRequest ,
318319) -> actix_web:: Result < MiddlewareResponse > {
319320 log:: trace!( "Started OIDC middleware request handling" ) ;
320- let response = match get_authenticated_user_info ( oidc_state, & request) {
321+ let response = match get_authenticated_user_info ( oidc_state, & request) . await {
321322 Ok ( Some ( claims) ) => {
322323 if request. path ( ) != SQLPAGE_REDIRECT_URI {
323324 log:: trace!( "Storing authenticated user info in request extensions: {claims:?}" ) ;
@@ -350,7 +351,7 @@ async fn handle_unauthenticated_request(
350351
351352 log:: debug!( "Redirecting to OIDC provider" ) ;
352353
353- let response = build_auth_provider_redirect_response ( oidc_state, & request) ;
354+ let response = build_auth_provider_redirect_response ( oidc_state, & request) . await ;
354355 Ok ( request. into_response ( response) )
355356}
356357
@@ -364,7 +365,7 @@ async fn handle_oidc_callback(
364365 Err ( e) => {
365366 log:: error!( "Failed to process OIDC callback with params {query_string}: {e}" ) ;
366367 oidc_state. refresh ( ) . await ;
367- let resp = build_auth_provider_redirect_response ( oidc_state, & request) ;
368+ let resp = build_auth_provider_redirect_response ( oidc_state, & request) . await ;
368369 Ok ( request. into_response ( resp) )
369370 }
370371 }
@@ -440,7 +441,7 @@ async fn process_oidc_callback(
440441 let redirect_target = validate_redirect_url ( state. initial_url ) ;
441442 log:: info!( "Redirecting to {redirect_target} after a successful login" ) ;
442443 let mut response = build_redirect_response ( redirect_target) ;
443- set_auth_cookie ( & mut response, & token_response, oidc_state) ?;
444+ set_auth_cookie ( & mut response, & token_response, oidc_state) . await ?;
444445 Ok ( response)
445446}
446447
@@ -458,7 +459,7 @@ async fn exchange_code_for_token(
458459 Ok ( token_response)
459460}
460461
461- fn set_auth_cookie (
462+ async fn set_auth_cookie (
462463 response : & mut HttpResponse ,
463464 token_response : & OidcTokenResponse ,
464465 oidc_state : & OidcState ,
@@ -469,7 +470,7 @@ fn set_auth_cookie(
469470 . id_token ( )
470471 . context ( "No ID token found in the token response. You may have specified an oauth2 provider that does not support OIDC." ) ?;
471472
472- let claims = oidc_state. get_token_claims ( id_token, None ) ?;
473+ let claims = oidc_state. get_token_claims ( id_token, None ) . await ?;
473474 let expiration = claims. expiration ( ) ;
474475 let max_age_seconds = expiration. signed_duration_since ( Utc :: now ( ) ) . num_seconds ( ) ;
475476
@@ -494,11 +495,11 @@ fn set_auth_cookie(
494495 Ok ( ( ) )
495496}
496497
497- fn build_auth_provider_redirect_response (
498+ async fn build_auth_provider_redirect_response (
498499 oidc_state : & OidcState ,
499500 request : & ServiceRequest ,
500501) -> HttpResponse {
501- let AuthUrl { url, params } = build_auth_url ( oidc_state) ;
502+ let AuthUrl { url, params } = build_auth_url ( oidc_state) . await ;
502503 let state_cookie = create_state_cookie ( request, params) ;
503504 HttpResponse :: TemporaryRedirect ( )
504505 . append_header ( ( "Location" , url. to_string ( ) ) )
@@ -513,7 +514,7 @@ fn build_redirect_response(target_url: String) -> HttpResponse {
513514}
514515
515516/// Returns the claims from the ID token in the `SQLPage` auth cookie.
516- fn get_authenticated_user_info (
517+ async fn get_authenticated_user_info (
517518 oidc_state : & OidcState ,
518519 request : & ServiceRequest ,
519520) -> anyhow:: Result < Option < OidcClaims > > {
@@ -525,7 +526,7 @@ fn get_authenticated_user_info(
525526 . with_context ( || format ! ( "Invalid SQLPage auth cookie: {cookie_value:?}" ) ) ?;
526527
527528 let state = get_state_from_cookie ( request) ?;
528- let claims = oidc_state. get_token_claims ( & id_token, Some ( & state) ) ?;
529+ let claims = oidc_state. get_token_claims ( & id_token, Some ( & state) ) . await ?;
529530 log:: debug!( "The current user is: {claims:?}" ) ;
530531 Ok ( Some ( claims) )
531532}
@@ -692,11 +693,11 @@ struct AuthUrlParams {
692693 nonce : Nonce ,
693694}
694695
695- fn build_auth_url ( oidc_state : & OidcState ) -> AuthUrl {
696+ async fn build_auth_url ( oidc_state : & OidcState ) -> AuthUrl {
696697 let nonce_source = Nonce :: new_random ( ) ;
697698 let hashed_nonce = Nonce :: new ( hash_nonce ( & nonce_source) ) ;
698699 let scopes = & oidc_state. config . scopes ;
699- let client_lock = oidc_state. client . lock ( ) . unwrap ( ) ;
700+ let client_lock = oidc_state. get_client ( ) . await ;
700701 let ( url, csrf_token, _nonce) = client_lock
701702 . client
702703 . authorize_url (
0 commit comments