@@ -198,6 +198,7 @@ pub struct ClientWithTime {
198198pub struct OidcState {
199199 pub config : OidcConfig ,
200200 client : RwLock < ClientWithTime > ,
201+ refreshing : std:: sync:: atomic:: AtomicBool ,
201202}
202203
203204impl OidcState {
@@ -212,36 +213,51 @@ impl OidcState {
212213 end_session_endpoint,
213214 last_update : Instant :: now ( ) ,
214215 } ) ,
216+ refreshing : std:: sync:: atomic:: AtomicBool :: new ( false ) ,
215217 } )
216218 }
217219
218- async fn refresh ( & self , service_request : & ServiceRequest ) {
219- let mut write_guard = self . client . write ( ) . await ;
220- match build_oidc_client_from_appdata ( & self . config , service_request) . await {
221- Ok ( ( http_client, end_session_endpoint) ) => {
222- * write_guard = ClientWithTime {
223- client : http_client,
224- end_session_endpoint,
225- last_update : Instant :: now ( ) ,
220+ /// Spawns a background task to refresh the OIDC client from the provider
221+ /// metadata URL if it hasn't been refreshed within `max_age`.
222+ /// Returns immediately without blocking the caller.
223+ /// Multiple concurrent calls are deduplicated via an atomic flag.
224+ fn refresh_in_background ( self : & Arc < Self > , http_client : & Client , max_age : Duration ) {
225+ use std:: sync:: atomic:: Ordering ;
226+ let Ok ( last_update) = self . client . try_read ( ) . map ( |g| g. last_update ) else {
227+ return ; // write lock held → a refresh is already in progress
228+ } ;
229+ if last_update. elapsed ( ) <= max_age {
230+ return ;
231+ }
232+ if self . refreshing . swap ( true , Ordering :: AcqRel ) {
233+ return ; // another refresh is already running
234+ }
235+ let state = Arc :: clone ( self ) ;
236+ let http_client = http_client. clone ( ) ;
237+ tokio:: task:: spawn_local ( async move {
238+ match build_oidc_client ( & state. config , & http_client) . await {
239+ Ok ( ( client, end_session_endpoint) ) => {
240+ * state. client . write ( ) . await = ClientWithTime {
241+ client,
242+ end_session_endpoint,
243+ last_update : Instant :: now ( ) ,
244+ } ;
226245 }
246+ Err ( e) => log:: error!( "Failed to refresh OIDC client: {e:#}" ) ,
227247 }
228- Err ( e ) => log :: error! ( "Failed to refresh OIDC client: {e:#}" ) ,
229- }
248+ state . refreshing . store ( false , Ordering :: Release ) ;
249+ } ) ;
230250 }
231251
232252 /// Refreshes the OIDC client from the provider metadata URL if it has expired.
233253 /// Most providers update their signing keys periodically.
234- pub async fn refresh_if_expired ( & self , service_request : & ServiceRequest ) {
235- if self . client . read ( ) . await . last_update . elapsed ( ) > OIDC_CLIENT_MAX_REFRESH_INTERVAL {
236- self . refresh ( service_request) . await ;
237- }
254+ pub fn refresh_if_expired ( self : & Arc < Self > , http_client : & Client ) {
255+ self . refresh_in_background ( http_client, OIDC_CLIENT_MAX_REFRESH_INTERVAL ) ;
238256 }
239257
240258 /// When an authentication error is encountered, refresh the OIDC client info faster
241- pub async fn refresh_on_error ( & self , service_request : & ServiceRequest ) {
242- if self . client . read ( ) . await . last_update . elapsed ( ) > OIDC_CLIENT_MIN_REFRESH_INTERVAL {
243- self . refresh ( service_request) . await ;
244- }
259+ pub fn refresh_on_error ( self : & Arc < Self > , http_client : & Client ) {
260+ self . refresh_in_background ( http_client, OIDC_CLIENT_MIN_REFRESH_INTERVAL ) ;
245261 }
246262
247263 /// Gets a reference to the oidc client, potentially generating a new one if needed
@@ -252,6 +268,15 @@ impl OidcState {
252268 )
253269 }
254270
271+ /// Forces the OIDC client to appear stale so that the next request triggers a refresh.
272+ #[ doc( hidden) ]
273+ pub async fn force_expire ( & self ) {
274+ self . client . write ( ) . await . last_update =
275+ Instant :: now ( )
276+ . checked_sub ( OIDC_CLIENT_MAX_REFRESH_INTERVAL + Duration :: from_secs ( 1 ) )
277+ . unwrap_or ( Instant :: now ( ) ) ;
278+ }
279+
255280 pub async fn get_end_session_endpoint ( & self ) -> Option < EndSessionUrl > {
256281 self . client . read ( ) . await . end_session_endpoint . clone ( )
257282 }
@@ -309,14 +334,6 @@ pub async fn initialize_oidc_state(
309334 ) ) )
310335}
311336
312- async fn build_oidc_client_from_appdata (
313- cfg : & OidcConfig ,
314- req : & ServiceRequest ,
315- ) -> anyhow:: Result < ( OidcClient , Option < EndSessionUrl > ) > {
316- let http_client = get_http_client_from_appdata ( req) ?;
317- build_oidc_client ( cfg, http_client) . await
318- }
319-
320337async fn build_oidc_client (
321338 oidc_cfg : & OidcConfig ,
322339 http_client : & Client ,
@@ -405,9 +422,14 @@ enum MiddlewareResponse {
405422 Respond ( ServiceResponse ) ,
406423}
407424
408- async fn handle_request ( oidc_state : & OidcState , request : ServiceRequest ) -> MiddlewareResponse {
425+ async fn handle_request (
426+ oidc_state : & Arc < OidcState > ,
427+ request : ServiceRequest ,
428+ ) -> MiddlewareResponse {
409429 log:: trace!( "Started OIDC middleware request handling" ) ;
410- oidc_state. refresh_if_expired ( & request) . await ;
430+ if let Ok ( http_client) = get_http_client_from_appdata ( & request) {
431+ oidc_state. refresh_if_expired ( http_client) ;
432+ }
411433
412434 if request. path ( ) == oidc_state. config . redirect_uri {
413435 let response = handle_oidc_callback ( oidc_state, request) . await ;
@@ -431,7 +453,9 @@ async fn handle_request(oidc_state: &OidcState, request: ServiceRequest) -> Midd
431453 }
432454 Err ( e) => {
433455 log:: debug!( "An auth cookie is present but could not be verified. Redirecting to OIDC provider to re-authenticate. {e:?}" ) ;
434- oidc_state. refresh_on_error ( & request) . await ;
456+ if let Ok ( http_client) = get_http_client_from_appdata ( & request) {
457+ oidc_state. refresh_on_error ( http_client) ;
458+ }
435459 handle_unauthenticated_request ( oidc_state, request) . await
436460 }
437461 }
@@ -456,7 +480,10 @@ async fn handle_unauthenticated_request(
456480 MiddlewareResponse :: Respond ( request. into_response ( response) )
457481}
458482
459- async fn handle_oidc_callback ( oidc_state : & OidcState , request : ServiceRequest ) -> ServiceResponse {
483+ async fn handle_oidc_callback (
484+ oidc_state : & Arc < OidcState > ,
485+ request : ServiceRequest ,
486+ ) -> ServiceResponse {
460487 match process_oidc_callback ( oidc_state, & request) . await {
461488 Ok ( mut response) => {
462489 clear_redirect_count_cookie ( & mut response) ;
@@ -467,7 +494,7 @@ async fn handle_oidc_callback(oidc_state: &OidcState, request: ServiceRequest) -
467494}
468495
469496async fn handle_oidc_callback_error (
470- oidc_state : & OidcState ,
497+ oidc_state : & Arc < OidcState > ,
471498 request : ServiceRequest ,
472499 e : anyhow:: Error ,
473500) -> ServiceResponse {
@@ -478,7 +505,9 @@ async fn handle_oidc_callback_error(
478505 log:: error!(
479506 "Failed to process OIDC callback (attempt {redirect_count}). Refreshing oidc provider metadata, then redirecting to home page: {e:#}"
480507 ) ;
481- oidc_state. refresh_on_error ( & request) . await ;
508+ if let Ok ( http_client) = get_http_client_from_appdata ( & request) {
509+ oidc_state. refresh_on_error ( http_client) ;
510+ }
482511 let resp = build_auth_provider_redirect_response ( oidc_state, "/" , redirect_count) . await ;
483512 request. into_response ( resp)
484513}
0 commit comments