Skip to content

Commit d5f7a84

Browse files
lovasoaclaude
andcommitted
rethink OIDC state: std::sync::RwLock<Arc<Snapshot>> for lock-free reads
Replace tokio::sync::RwLock<ClientWithTime> with std::sync::RwLock<Arc<OidcSnapshot>>. The std lock makes it structurally impossible to hold across await points. Readers clone an Arc (nanoseconds) and use it freely — no lock contention. Many previously-async functions become synchronous (get_token_claims, build_auth_url, handle_unauthenticated_request, handle_oidc_logout, etc). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 817674c commit d5f7a84

1 file changed

Lines changed: 82 additions & 89 deletions

File tree

src/webserver/oidc.rs

Lines changed: 82 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ use openidconnect::{
3434
StandardTokenResponse,
3535
};
3636
use serde::{Deserialize, Serialize};
37-
use tokio::sync::{RwLock, RwLockReadGuard};
3837

3938
use super::error::anyhow_err_to_actix_resp;
4039
use super::http_client::make_http_client;
@@ -189,16 +188,21 @@ fn get_app_host(config: &AppConfig) -> String {
189188
host
190189
}
191190

192-
pub struct ClientWithTime {
191+
/// A point-in-time snapshot of the OIDC provider's client and metadata.
192+
/// Cheaply cloneable via Arc — callers never hold a lock while using this.
193+
struct OidcSnapshot {
193194
client: OidcClient,
194195
end_session_endpoint: Option<EndSessionUrl>,
195-
last_update: Instant,
196+
created_at: Instant,
196197
}
197198

198199
pub struct OidcState {
199200
pub config: OidcConfig,
200-
client: RwLock<ClientWithTime>,
201-
refreshing: std::sync::atomic::AtomicBool,
201+
/// Current snapshot. The lock is only held for the instant
202+
/// needed to clone/swap the Arc — never across await points.
203+
snapshot: std::sync::RwLock<Arc<OidcSnapshot>>,
204+
/// Prevents concurrent background refreshes.
205+
refresh_in_progress: std::sync::atomic::AtomicBool,
202206
}
203207

204208
impl OidcState {
@@ -208,101 +212,91 @@ impl OidcState {
208212

209213
Ok(Self {
210214
config: oidc_cfg,
211-
client: RwLock::new(ClientWithTime {
215+
snapshot: std::sync::RwLock::new(Arc::new(OidcSnapshot {
212216
client,
213217
end_session_endpoint,
214-
last_update: Instant::now(),
215-
}),
216-
refreshing: std::sync::atomic::AtomicBool::new(false),
218+
created_at: Instant::now(),
219+
})),
220+
refresh_in_progress: std::sync::atomic::AtomicBool::new(false),
217221
})
218222
}
219223

220-
/// Spawns a background task to refresh the OIDC client from the provider
221-
/// metadata URL if it hasn't been refreshed within `max_age`.
222-
/// Returns immediately without blocking the caller.
223-
/// Multiple concurrent calls are deduplicated via an atomic flag.
224-
fn refresh_in_background(self: &Arc<Self>, http_client: &Client, max_age: Duration) {
224+
/// Returns the current snapshot. Never blocks in practice.
225+
fn snapshot(&self) -> Arc<OidcSnapshot> {
226+
self.snapshot.read().unwrap().clone()
227+
}
228+
229+
/// If the snapshot is older than `max_age` and no refresh is already running,
230+
/// spawns a background task to fetch new provider metadata.
231+
/// Returns immediately — never blocks the caller on I/O.
232+
pub fn maybe_refresh(self: &Arc<Self>, http_client: &Client, max_age: Duration) {
225233
use std::sync::atomic::Ordering;
226-
let Ok(last_update) = self.client.try_read().map(|g| g.last_update) else {
227-
return; // write lock held → a refresh is already in progress
228-
};
229-
if last_update.elapsed() <= max_age {
234+
if self.snapshot().created_at.elapsed() <= max_age {
230235
return;
231236
}
232-
if self.refreshing.swap(true, Ordering::AcqRel) {
233-
return; // another refresh is already running
237+
if self.refresh_in_progress.swap(true, Ordering::AcqRel) {
238+
return;
234239
}
235240
let state = Arc::clone(self);
236241
let http_client = http_client.clone();
237242
tokio::task::spawn_local(async move {
238243
match build_oidc_client(&state.config, &http_client).await {
239244
Ok((client, end_session_endpoint)) => {
240-
*state.client.write().await = ClientWithTime {
245+
*state.snapshot.write().unwrap() = Arc::new(OidcSnapshot {
241246
client,
242247
end_session_endpoint,
243-
last_update: Instant::now(),
244-
};
248+
created_at: Instant::now(),
249+
});
245250
}
246251
Err(e) => log::error!("Failed to refresh OIDC client: {e:#}"),
247252
}
248-
state.refreshing.store(false, Ordering::Release);
253+
state
254+
.refresh_in_progress
255+
.store(false, Ordering::Release);
249256
});
250257
}
251258

252-
/// Refreshes the OIDC client from the provider metadata URL if it has expired.
253-
/// Most providers update their signing keys periodically.
254-
pub fn refresh_if_expired(self: &Arc<Self>, http_client: &Client) {
255-
self.refresh_in_background(http_client, OIDC_CLIENT_MAX_REFRESH_INTERVAL);
256-
}
257-
258-
/// When an authentication error is encountered, refresh the OIDC client info faster
259-
pub fn refresh_on_error(self: &Arc<Self>, http_client: &Client) {
260-
self.refresh_in_background(http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL);
261-
}
262-
263-
/// Gets a reference to the oidc client, potentially generating a new one if needed
264-
pub async fn get_client(&self) -> RwLockReadGuard<'_, OidcClient> {
265-
RwLockReadGuard::map(
266-
self.client.read().await,
267-
|ClientWithTime { client, .. }| client,
268-
)
269-
}
270-
271259
/// Forces the OIDC client to appear stale so that the next request triggers a refresh.
272260
#[doc(hidden)]
273-
pub async fn force_expire(&self) {
274-
self.client.write().await.last_update =
275-
Instant::now()
261+
pub fn force_expire(&self) {
262+
let mut guard = self.snapshot.write().unwrap();
263+
let old = &**guard;
264+
*guard = Arc::new(OidcSnapshot {
265+
client: old.client.clone(),
266+
end_session_endpoint: old.end_session_endpoint.clone(),
267+
created_at: Instant::now()
276268
.checked_sub(OIDC_CLIENT_MAX_REFRESH_INTERVAL + Duration::from_secs(1))
277-
.unwrap_or(Instant::now());
269+
.unwrap_or(Instant::now()),
270+
});
278271
}
279272

280-
pub async fn get_end_session_endpoint(&self) -> Option<EndSessionUrl> {
281-
self.client.read().await.end_session_endpoint.clone()
273+
pub fn end_session_endpoint(&self) -> Option<EndSessionUrl> {
274+
self.snapshot().end_session_endpoint.clone()
282275
}
283276

284-
/// Validate and decode the claims of an OIDC token, without refreshing the client.
285-
async fn get_token_claims(
277+
/// Validate and decode the claims of an OIDC token.
278+
fn get_token_claims(
286279
&self,
287280
id_token: OidcToken,
288281
expected_nonce: &Nonce,
289282
) -> anyhow::Result<OidcClaims> {
290-
let client = &self.get_client().await;
291-
let verifier = self.config.create_id_token_verifier(client);
283+
let snapshot = self.snapshot();
284+
let verifier = self.config.create_id_token_verifier(&snapshot.client);
292285
let nonce_verifier = |nonce: Option<&Nonce>| check_nonce(nonce, expected_nonce);
293286
let claims: OidcClaims = id_token
294287
.into_claims(&verifier, nonce_verifier)
295288
.map_err(|e| anyhow::anyhow!("Could not verify the ID token: {e}"))?;
296289
Ok(claims)
297290
}
298291

299-
/// Builds an absolute redirect URI by joining the relative redirect URI with the client's redirect URL
300-
pub async fn build_absolute_redirect_uri(
292+
/// Builds an absolute redirect URI from the client's configured redirect URL.
293+
pub fn build_absolute_redirect_uri(
301294
&self,
302295
relative_redirect_uri: &str,
303296
) -> anyhow::Result<String> {
304-
let client_guard = self.get_client().await;
305-
let client_redirect_url = client_guard
297+
let snapshot = self.snapshot();
298+
let client_redirect_url = snapshot
299+
.client
306300
.redirect_uri()
307301
.ok_or_else(|| anyhow!("OIDC client has no redirect URL configured"))?;
308302
let absolute_redirect_uri = client_redirect_url
@@ -427,8 +421,9 @@ async fn handle_request(
427421
request: ServiceRequest,
428422
) -> MiddlewareResponse {
429423
log::trace!("Started OIDC middleware request handling");
430-
if let Ok(http_client) = get_http_client_from_appdata(&request) {
431-
oidc_state.refresh_if_expired(http_client);
424+
let http_client = get_http_client_from_appdata(&request).ok();
425+
if let Some(c) = http_client {
426+
oidc_state.maybe_refresh(c, OIDC_CLIENT_MAX_REFRESH_INTERVAL);
432427
}
433428

434429
if request.path() == oidc_state.config.redirect_uri {
@@ -437,31 +432,31 @@ async fn handle_request(
437432
}
438433

439434
if request.path() == oidc_state.config.logout_uri {
440-
let response = handle_oidc_logout(oidc_state, request).await;
435+
let response = handle_oidc_logout(oidc_state, request);
441436
return MiddlewareResponse::Respond(response);
442437
}
443438

444-
match get_authenticated_user_info(oidc_state, &request).await {
439+
match get_authenticated_user_info(oidc_state, &request) {
445440
Ok(Some(claims)) => {
446441
log::trace!("Storing authenticated user info in request extensions: {claims:?}");
447442
request.extensions_mut().insert(claims);
448443
MiddlewareResponse::Forward(request)
449444
}
450445
Ok(None) => {
451446
log::trace!("No authenticated user found");
452-
handle_unauthenticated_request(oidc_state, request).await
447+
handle_unauthenticated_request(oidc_state, request)
453448
}
454449
Err(e) => {
455450
log::debug!("An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}");
456-
if let Ok(http_client) = get_http_client_from_appdata(&request) {
457-
oidc_state.refresh_on_error(http_client);
451+
if let Some(c) = http_client {
452+
oidc_state.maybe_refresh(c, OIDC_CLIENT_MIN_REFRESH_INTERVAL);
458453
}
459-
handle_unauthenticated_request(oidc_state, request).await
454+
handle_unauthenticated_request(oidc_state, request)
460455
}
461456
}
462457
}
463458

464-
async fn handle_unauthenticated_request(
459+
fn handle_unauthenticated_request(
465460
oidc_state: &OidcState,
466461
request: ServiceRequest,
467462
) -> MiddlewareResponse {
@@ -476,7 +471,7 @@ async fn handle_unauthenticated_request(
476471
let initial_url = request.uri().to_string();
477472
let redirect_count = get_redirect_count(&request);
478473
let response =
479-
build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count).await;
474+
build_auth_provider_redirect_response(oidc_state, &initial_url, redirect_count);
480475
MiddlewareResponse::Respond(request.into_response(response))
481476
}
482477

@@ -489,26 +484,26 @@ async fn handle_oidc_callback(
489484
clear_redirect_count_cookie(&mut response);
490485
request.into_response(response)
491486
}
492-
Err(e) => handle_oidc_callback_error(oidc_state, request, e).await,
487+
Err(e) => handle_oidc_callback_error(oidc_state, request, &e),
493488
}
494489
}
495490

496-
async fn handle_oidc_callback_error(
491+
fn handle_oidc_callback_error(
497492
oidc_state: &Arc<OidcState>,
498493
request: ServiceRequest,
499-
e: anyhow::Error,
494+
e: &anyhow::Error,
500495
) -> ServiceResponse {
501496
let redirect_count = get_redirect_count(&request);
502497
if redirect_count >= MAX_OIDC_REDIRECTS {
503-
return handle_max_redirect_count_reached(request, &e, redirect_count);
498+
return handle_max_redirect_count_reached(request, e, redirect_count);
504499
}
505500
log::error!(
506501
"Failed to process OIDC callback (attempt {redirect_count}). Refreshing oidc provider metadata, then redirecting to home page: {e:#}"
507502
);
508503
if let Ok(http_client) = get_http_client_from_appdata(&request) {
509-
oidc_state.refresh_on_error(http_client);
504+
oidc_state.maybe_refresh(http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL);
510505
}
511-
let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count).await;
506+
let resp = build_auth_provider_redirect_response(oidc_state, "/", redirect_count);
512507
request.into_response(resp)
513508
}
514509

@@ -525,8 +520,8 @@ fn handle_max_redirect_count_reached(
525520
request.into_response(resp)
526521
}
527522

528-
async fn handle_oidc_logout(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse {
529-
match process_oidc_logout(oidc_state, &request).await {
523+
fn handle_oidc_logout(oidc_state: &OidcState, request: ServiceRequest) -> ServiceResponse {
524+
match process_oidc_logout(oidc_state, &request) {
530525
Ok(response) => request.into_response(response),
531526
Err(e) => {
532527
log::error!("Failed to process OIDC logout: {e:#}");
@@ -554,7 +549,7 @@ fn parse_logout_params(query: &str) -> anyhow::Result<LogoutParams> {
554549
.map(Query::into_inner)
555550
}
556551

557-
async fn process_oidc_logout(
552+
fn process_oidc_logout(
558553
oidc_state: &OidcState,
559554
request: &ServiceRequest,
560555
) -> anyhow::Result<HttpResponse> {
@@ -571,10 +566,9 @@ async fn process_oidc_logout(
571566
.flatten();
572567

573568
let mut response =
574-
if let Some(end_session_endpoint) = oidc_state.get_end_session_endpoint().await {
569+
if let Some(end_session_endpoint) = oidc_state.end_session_endpoint() {
575570
let absolute_redirect_uri = oidc_state
576-
.build_absolute_redirect_uri(&params.redirect_uri)
577-
.await?;
571+
.build_absolute_redirect_uri(&params.redirect_uri)?;
578572

579573
let post_logout_redirect_uri =
580574
PostLogoutRedirectUrl::new(absolute_redirect_uri.clone()).with_context(|| {
@@ -692,9 +686,9 @@ async fn process_oidc_callback(
692686
.into_inner();
693687
log::debug!("Processing OIDC callback with params: {params:?}. Requesting token...");
694688
let mut tmp_login_flow_state_cookie = get_tmp_login_flow_state_cookie(request, &params.state)?;
695-
let client = oidc_state.get_client().await;
689+
let snapshot = oidc_state.snapshot();
696690
let http_client = get_http_client_from_appdata(request)?;
697-
let id_token = exchange_code_for_token(&client, http_client, params.clone()).await?;
691+
let id_token = exchange_code_for_token(&snapshot.client, http_client, params.clone()).await?;
698692
log::debug!("Received token response: {id_token:?}");
699693
let LoginFlowState {
700694
nonce,
@@ -708,7 +702,6 @@ async fn process_oidc_callback(
708702
set_auth_cookie(&mut response, &id_token);
709703
let claims = oidc_state
710704
.get_token_claims(id_token, &nonce)
711-
.await
712705
.context("The identity provider returned an invalid ID token")?;
713706
log::debug!("{} successfully logged in", claims.subject().as_str());
714707
let nonce_cookie = create_final_nonce_cookie(&nonce);
@@ -759,12 +752,12 @@ fn set_auth_cookie(response: &mut HttpResponse, id_token: &OidcToken) {
759752
response.add_cookie(&cookie).unwrap();
760753
}
761754

762-
async fn build_auth_provider_redirect_response(
755+
fn build_auth_provider_redirect_response(
763756
oidc_state: &OidcState,
764757
initial_url: &str,
765758
redirect_count: u8,
766759
) -> HttpResponse {
767-
let AuthUrl { url, params } = build_auth_url(oidc_state).await;
760+
let AuthUrl { url, params } = build_auth_url(oidc_state);
768761
let tmp_login_flow_state_cookie = create_tmp_login_flow_state_cookie(&params, initial_url);
769762
let redirect_count_cookie = Cookie::build(
770763
SQLPAGE_OIDC_REDIRECT_COUNT_COOKIE,
@@ -811,7 +804,7 @@ fn build_oidc_error_response(request: &ServiceRequest, e: &anyhow::Error) -> Htt
811804
}
812805

813806
/// Returns the claims from the ID token in the `SQLPage` auth cookie.
814-
async fn get_authenticated_user_info(
807+
fn get_authenticated_user_info(
815808
oidc_state: &OidcState,
816809
request: &ServiceRequest,
817810
) -> anyhow::Result<Option<OidcClaims>> {
@@ -824,7 +817,7 @@ async fn get_authenticated_user_info(
824817

825818
let nonce = get_final_nonce_from_cookie(request)?;
826819
log::debug!("Verifying id token: {id_token:?}");
827-
let claims = oidc_state.get_token_claims(id_token, &nonce).await?;
820+
let claims = oidc_state.get_token_claims(id_token, &nonce)?;
828821
log::debug!("The current user is: {claims:?}");
829822
Ok(Some(claims))
830823
}
@@ -992,12 +985,12 @@ struct AuthUrlParams {
992985
nonce: Nonce,
993986
}
994987

995-
async fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
988+
fn build_auth_url(oidc_state: &OidcState) -> AuthUrl {
996989
let nonce_source = Nonce::new_random();
997990
let hashed_nonce = Nonce::new(hash_nonce(&nonce_source));
998991
let scopes = &oidc_state.config.scopes;
999-
let client_lock = oidc_state.get_client().await;
1000-
let (url, csrf_token, _nonce) = client_lock
992+
let snapshot = oidc_state.snapshot();
993+
let (url, csrf_token, _nonce) = snapshot.client
1001994
.authorize_url(
1002995
CoreAuthenticationFlow::AuthorizationCode,
1003996
CsrfToken::new_random,

0 commit comments

Comments
 (0)