Skip to content

Commit f487220

Browse files
cursoragentlovasoa
andcommitted
Refactor OIDC client management and remove unnecessary HTTP client
Co-authored-by: contact <contact@ophir.dev>
1 parent b8c51e2 commit f487220

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

src/webserver/oidc.rs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,28 @@ fn get_app_host(config: &AppConfig) -> String {
167167
pub struct OidcState {
168168
pub config: Arc<OidcConfig>,
169169
cached_provider: Arc<RwLock<CachedProvider>>,
170-
http_client: Arc<Client>,
171170
}
172171

173172
impl OidcState {
173+
/// Get the current OIDC client, checking if cache is stale but not attempting refresh
174+
pub fn get_client(&self) -> OidcClient {
175+
// For now, we'll use a simple approach - get the current client
176+
// In a production system, you might want to check if cache is stale
177+
// and trigger an async refresh task
178+
futures_util::executor::block_on(async {
179+
self.cached_provider.read().await.client.clone()
180+
})
181+
}
182+
174183
/// Get the current OIDC client, refreshing if stale and possible
175-
pub async fn get_client(&self) -> OidcClient {
184+
pub async fn get_client_with_refresh(&self, app_config: &AppConfig) -> OidcClient {
176185
// Try to refresh if cache is stale and we haven't tried recently
177186
{
178187
let cache = self.cached_provider.read().await;
179188
if cache.is_stale() && cache.can_refresh() {
180189
// Release read lock before attempting refresh
181190
drop(cache);
182-
if let Err(e) = self.refresh_provider().await {
191+
if let Err(e) = self.refresh_provider(app_config).await {
183192
log::warn!("Failed to refresh OIDC provider: {}", e);
184193
}
185194
}
@@ -189,7 +198,7 @@ impl OidcState {
189198
}
190199

191200
/// Refresh provider metadata and client from the OIDC provider
192-
async fn refresh_provider(&self) -> anyhow::Result<()> {
201+
async fn refresh_provider(&self, app_config: &AppConfig) -> anyhow::Result<()> {
193202
let mut cache = self.cached_provider.write().await;
194203

195204
// Double-check we can refresh (another thread might have just done it)
@@ -204,8 +213,9 @@ impl OidcState {
204213
self.config.issuer_url
205214
);
206215

216+
let http_client = make_http_client(app_config)?;
207217
let new_metadata =
208-
discover_provider_metadata(&self.http_client, self.config.issuer_url.clone()).await?;
218+
discover_provider_metadata(&http_client, self.config.issuer_url.clone()).await?;
209219
let new_client = make_oidc_client(&self.config, new_metadata.clone())?;
210220

211221
cache.update(new_client, new_metadata);
@@ -224,7 +234,7 @@ pub async fn initialize_oidc_state(
224234
Err(Some(e)) => return Err(anyhow::anyhow!(e)),
225235
};
226236

227-
let http_client = Arc::new(make_http_client(app_config)?);
237+
let http_client = make_http_client(app_config)?;
228238

229239
// Initial metadata discovery
230240
let provider_metadata =
@@ -234,7 +244,6 @@ pub async fn initialize_oidc_state(
234244
let oidc_state = Arc::new(OidcState {
235245
config: oidc_cfg,
236246
cached_provider: Arc::new(RwLock::new(CachedProvider::new(client, provider_metadata))),
237-
http_client,
238247
});
239248

240249
Ok(Some(oidc_state))
@@ -329,7 +338,7 @@ where
329338
}
330339

331340
log::debug!("Redirecting to OIDC provider");
332-
let client = oidc_state.get_client().await;
341+
let client = oidc_state.get_client();
333342
let response = build_auth_provider_redirect_response(&client, &oidc_state.config, &request);
334343
Ok(request.into_response(response))
335344
}
@@ -338,7 +347,7 @@ async fn handle_oidc_callback(
338347
oidc_state: Arc<OidcState>,
339348
request: ServiceRequest,
340349
) -> Result<ServiceResponse<BoxBody>, Error> {
341-
let oidc_client = oidc_state.get_client().await;
350+
let oidc_client = oidc_state.get_client();
342351
let query_string = request.query_string();
343352
match process_oidc_callback(&oidc_client, query_string, &request).await {
344353
Ok(response) => Ok(request.into_response(response)),
@@ -353,7 +362,7 @@ async fn handle_oidc_callback(
353362

354363
impl<S> Service<ServiceRequest> for OidcService<S>
355364
where
356-
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error>,
365+
S: Service<ServiceRequest, Response = ServiceResponse<BoxBody>, Error = Error> + Clone,
357366
S::Future: 'static,
358367
{
359368
type Response = ServiceResponse<BoxBody>;
@@ -369,7 +378,7 @@ where
369378
let service = self.service.clone();
370379

371380
Box::pin(async move {
372-
let oidc_client = oidc_state.get_client().await;
381+
let oidc_client = oidc_state.get_client();
373382
match get_authenticated_user_info(&oidc_client, &request) {
374383
Ok(Some(claims)) => {
375384
log::trace!(

0 commit comments

Comments
 (0)