Skip to content

Commit 45540e7

Browse files
committed
Replace Mutex with RwLock for OIDC client
This change switches from std::sync::Mutex to tokio::sync::RwLock for the OIDC client to avoid deadlocks. The read operations now use a shared lock while writes remain exclusive.
1 parent 5cf33e7 commit 45540e7

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

src/webserver/oidc.rs

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use openidconnect::{
3232
StandardTokenResponse,
3333
};
3434
use serde::{Deserialize, Serialize};
35-
use std::sync::{Mutex, MutexGuard};
35+
use tokio::sync::{RwLock, RwLockReadGuard};
3636

3737
use super::http_client::make_http_client;
3838

@@ -151,7 +151,7 @@ pub struct ClientWithTime {
151151
pub struct OidcState {
152152
pub config: OidcConfig,
153153
app_config: AppConfig,
154-
client: Mutex<ClientWithTime>,
154+
client: RwLock<ClientWithTime>,
155155
}
156156

157157
impl OidcState {
@@ -162,7 +162,7 @@ impl OidcState {
162162
Ok(Self {
163163
config: oidc_cfg,
164164
app_config,
165-
client: Mutex::new(ClientWithTime {
165+
client: RwLock::new(ClientWithTime {
166166
client,
167167
last_update: Instant::now(),
168168
}),
@@ -174,9 +174,10 @@ impl OidcState {
174174
log::error!("Failed to create HTTP client");
175175
return;
176176
};
177+
let mut write_guard = self.client.write().await;
177178
match build_oidc_client(&self.config, &http_client).await {
178179
Ok(client) => {
179-
*self.client.lock().expect("oidc client") = ClientWithTime {
180+
*write_guard = ClientWithTime {
180181
client,
181182
last_update: Instant::now(),
182183
};
@@ -188,26 +189,26 @@ impl OidcState {
188189
}
189190

190191
/// Gets a reference to the oidc client, potentially generating a new one if needed
191-
pub async fn get_client(&self) -> MutexGuard<'_, ClientWithTime> {
192+
pub async fn get_client(&self) -> RwLockReadGuard<'_, ClientWithTime> {
192193
{
193-
let client_lock = self.client.lock().expect("oidc client");
194+
let client_lock = self.client.read().await;
194195
if client_lock.last_update.elapsed() < OIDC_CLIENT_REFRESH_INTERVAL {
195196
return client_lock;
196197
}
197198
}
198199
log::debug!("OIDC client is older than {OIDC_CLIENT_REFRESH_INTERVAL:?}, refreshing...");
199200
self.refresh().await;
200-
self.client.lock().expect("oidc client")
201+
self.client.read().await
201202
}
202203

203204
/// Validate and decode the claims of an OIDC token, without refreshing the client.
204-
fn get_token_claims(
205+
async fn get_token_claims(
205206
&self,
206207
id_token: &OidcToken,
207208
state: Option<&OidcLoginState>,
208209
) -> anyhow::Result<OidcClaims> {
209210
// Do not refresh the client on every check
210-
let client = &self.client.lock().expect("oidc client").client;
211+
let client = &self.client.read().await.client;
211212
let verifier = self.config.create_id_token_verifier(client);
212213
let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, state);
213214
let claims: OidcClaims = id_token
@@ -317,7 +318,7 @@ async fn handle_request(
317318
request: ServiceRequest,
318319
) -> actix_web::Result<MiddlewareResponse> {
319320
log::trace!("Started OIDC middleware request handling");
320-
let response = match get_authenticated_user_info(oidc_state, &request) {
321+
let response = match get_authenticated_user_info(oidc_state, &request).await {
321322
Ok(Some(claims)) => {
322323
if request.path() != SQLPAGE_REDIRECT_URI {
323324
log::trace!("Storing authenticated user info in request extensions: {claims:?}");
@@ -350,7 +351,7 @@ async fn handle_unauthenticated_request(
350351

351352
log::debug!("Redirecting to OIDC provider");
352353

353-
let response = build_auth_provider_redirect_response(oidc_state, &request);
354+
let response = build_auth_provider_redirect_response(oidc_state, &request).await;
354355
Ok(request.into_response(response))
355356
}
356357

@@ -364,7 +365,7 @@ async fn handle_oidc_callback(
364365
Err(e) => {
365366
log::error!("Failed to process OIDC callback with params {query_string}: {e}");
366367
oidc_state.refresh().await;
367-
let resp = build_auth_provider_redirect_response(oidc_state, &request);
368+
let resp = build_auth_provider_redirect_response(oidc_state, &request).await;
368369
Ok(request.into_response(resp))
369370
}
370371
}
@@ -440,7 +441,7 @@ async fn process_oidc_callback(
440441
let redirect_target = validate_redirect_url(state.initial_url);
441442
log::info!("Redirecting to {redirect_target} after a successful login");
442443
let mut response = build_redirect_response(redirect_target);
443-
set_auth_cookie(&mut response, &token_response, oidc_state)?;
444+
set_auth_cookie(&mut response, &token_response, oidc_state).await?;
444445
Ok(response)
445446
}
446447

@@ -458,7 +459,7 @@ async fn exchange_code_for_token(
458459
Ok(token_response)
459460
}
460461

461-
fn set_auth_cookie(
462+
async fn set_auth_cookie(
462463
response: &mut HttpResponse,
463464
token_response: &OidcTokenResponse,
464465
oidc_state: &OidcState,
@@ -469,7 +470,7 @@ fn set_auth_cookie(
469470
.id_token()
470471
.context("No ID token found in the token response. You may have specified an oauth2 provider that does not support OIDC.")?;
471472

472-
let claims = oidc_state.get_token_claims(id_token, None)?;
473+
let claims = oidc_state.get_token_claims(id_token, None).await?;
473474
let expiration = claims.expiration();
474475
let max_age_seconds = expiration.signed_duration_since(Utc::now()).num_seconds();
475476

@@ -494,11 +495,11 @@ fn set_auth_cookie(
494495
Ok(())
495496
}
496497

497-
fn build_auth_provider_redirect_response(
498+
async fn build_auth_provider_redirect_response(
498499
oidc_state: &OidcState,
499500
request: &ServiceRequest,
500501
) -> HttpResponse {
501-
let AuthUrl { url, params } = build_auth_url(oidc_state);
502+
let AuthUrl { url, params } = build_auth_url(oidc_state).await;
502503
let state_cookie = create_state_cookie(request, params);
503504
HttpResponse::TemporaryRedirect()
504505
.append_header(("Location", url.to_string()))
@@ -513,7 +514,7 @@ fn build_redirect_response(target_url: String) -> HttpResponse {
513514
}
514515

515516
/// Returns the claims from the ID token in the `SQLPage` auth cookie.
516-
fn get_authenticated_user_info(
517+
async fn get_authenticated_user_info(
517518
oidc_state: &OidcState,
518519
request: &ServiceRequest,
519520
) -> anyhow::Result<Option<OidcClaims>> {
@@ -525,7 +526,7 @@ fn get_authenticated_user_info(
525526
.with_context(|| format!("Invalid SQLPage auth cookie: {cookie_value:?}"))?;
526527

527528
let state = get_state_from_cookie(request)?;
528-
let claims = oidc_state.get_token_claims(&id_token, Some(&state))?;
529+
let claims = oidc_state.get_token_claims(&id_token, Some(&state)).await?;
529530
log::debug!("The current user is: {claims:?}");
530531
Ok(Some(claims))
531532
}
@@ -692,11 +693,11 @@ struct AuthUrlParams {
692693
nonce: Nonce,
693694
}
694695

695-
fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
696+
async fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
696697
let nonce_source = Nonce::new_random();
697698
let hashed_nonce = Nonce::new(hash_nonce(&nonce_source));
698699
let scopes = &oidc_state.config.scopes;
699-
let client_lock = oidc_state.client.lock().unwrap();
700+
let client_lock = oidc_state.get_client().await;
700701
let (url, csrf_token, _nonce) = client_lock
701702
.client
702703
.authorize_url(

0 commit comments

Comments
 (0)