From ac3e0caa0b72b5a424abbc985e0634553cc0df47 Mon Sep 17 00:00:00 2001 From: Paul Zinselmeyer Date: Sat, 20 Apr 2024 20:35:04 +0200 Subject: [PATCH] implement fix for #10 fixed #10 by implementing a flag in the response extensions that instructs the middleware to clear the session. The flag is automatically set when using the `OidcRpInitiatedLogout` as a responder. improved documentation modified example to reflect api changes --- examples/basic/src/main.rs | 12 ++------- src/error.rs | 6 +++++ src/extractor.rs | 29 +++++++++++++++----- src/lib.rs | 8 +++++- src/middleware.rs | 55 ++++++++++++++++++++++++++------------ 5 files changed, 76 insertions(+), 34 deletions(-) diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index da5165f..6659890 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,9 +1,5 @@ use axum::{ - error_handling::HandleErrorLayer, - http::Uri, - response::{IntoResponse, Redirect}, - routing::get, - Router, + error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, @@ -84,9 +80,5 @@ async fn maybe_authenticated( } async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { - let logout_uri = logout - .with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) - .uri() - .unwrap(); - Redirect::temporary(&logout_uri.to_string()) + logout.with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) } diff --git a/src/error.rs b/src/error.rs index 6d8997e..c580d16 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,9 @@ pub enum ExtractorError { #[error("rp initiated logout information not found")] RpInitiatedLogoutInformationNotFound, + + #[error("could not build rp initiated logout uri")] + FailedToCreateRpInitiatedLogoutUri, } #[derive(Debug, Error)] @@ -88,6 +91,9 @@ impl IntoResponse for ExtractorError { Self::RpInitiatedLogoutInformationNotFound => { (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() } + Self::FailedToCreateRpInitiatedLogoutUri => { + (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() + } } } } diff --git a/src/extractor.rs b/src/extractor.rs index bf477c8..233f20d 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,14 +1,15 @@ use std::{borrow::Cow, ops::Deref}; -use crate::{error::ExtractorError, AdditionalClaims}; +use crate::{error::ExtractorError, AdditionalClaims, ClearSessionFlag}; use async_trait::async_trait; -use axum_core::extract::FromRequestParts; +use axum::response::Redirect; +use axum_core::{extract::FromRequestParts, response::IntoResponse}; use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; /// Extractor for the OpenID Connect Claims. /// -/// This Extractor will only return the Claims when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. +/// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. #[derive(Clone)] pub struct OidcClaims(pub IdTokenClaims); @@ -48,7 +49,7 @@ where /// Extractor for the OpenID Connect Access Token. /// -/// This Extractor will only return the Access Token when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. +/// This Extractor will only return the Access Token when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. #[derive(Clone)] pub struct OidcAccessToken(pub String); @@ -84,7 +85,7 @@ impl AsRef for OidcAccessToken { /// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL /// -/// This Extractor will only succed when the cached session is valid, [crate::middleware::OidcAuthMiddleware] is loaded and the issuer supports RP-Initialized Logout. +/// This Extractor will only succed when the cached session is valid, [`crate::middleware::OidcAuthMiddleware`] is loaded and the issuer supports RP-Initialized Logout. #[derive(Clone)] pub struct OidcRpInitiatedLogout { pub(crate) end_session_endpoint: Uri, @@ -106,7 +107,9 @@ impl OidcRpInitiatedLogout { self.state = Some(state); self } - /// get the uri that the client needs to access for logout + /// get the uri that the client needs to access for logout. This does **NOT** delete the + /// session in axum-oidc. You should use the [`ClearSessionFlag`] responder or include + /// [`OidcRpInitiatedLogout`] in the response extensions pub fn uri(&self) -> Result { let mut parts = self.end_session_endpoint.clone().into_parts(); @@ -159,3 +162,17 @@ where .ok_or(ExtractorError::Unauthorized) } } + +impl IntoResponse for OidcRpInitiatedLogout { + /// redirect to the logout uri and signal the [`crate::middleware::OidcAuthMiddleware`] that + /// the session should be cleared + fn into_response(self) -> axum_core::response::Response { + if let Ok(uri) = self.uri() { + let mut response = Redirect::temporary(&uri.to_string()).into_response(); + response.extensions_mut().insert(ClearSessionFlag); + response + } else { + ExtractorError::FailedToCreateRpInitiatedLogoutUri.into_response() + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 9eb2551..5b861a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -103,6 +103,8 @@ pub struct OidcClient { } impl OidcClient { + /// create a new [`OidcClient`] by fetching the required information from the + /// `/.well-known/openid-configuration` endpoint of the issuer. pub async fn discover_new( application_base_url: Uri, issuer: String, @@ -157,6 +159,7 @@ struct OidcSession { csrf_token: CsrfToken, pkce_verifier: PkceCodeVerifier, authenticated: Option>, + refresh_token: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -164,7 +167,6 @@ struct OidcSession { struct AuthenticatedSession { id_token: IdToken, access_token: AccessToken, - refresh_token: Option, } /// additional metadata that is discovered on client creation via the @@ -174,3 +176,7 @@ struct AdditionalProviderMetadata { end_session_endpoint: Option, } impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {} + +/// response extension flag to signal the [`OidcAuthLayer`] that the session should be cleared. +#[derive(Clone, Copy)] +pub struct ClearSessionFlag; diff --git a/src/middleware.rs b/src/middleware.rs index b823e66..97bbc09 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -26,11 +26,11 @@ use openidconnect::{ use crate::{ error::{Error, MiddlewareError}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, - AdditionalClaims, AuthenticatedSession, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, - SESSION_KEY, + AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, + OidcQuery, OidcSession, SESSION_KEY, }; -/// Layer for the [OidcLoginMiddleware]. +/// Layer for the [`OidcLoginMiddleware`]. #[derive(Clone, Default)] pub struct OidcLoginLayer where @@ -62,7 +62,7 @@ where } /// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect -/// Issuer to authenticate. This Middleware needs to be loaded afer [OidcAuthMiddleware]. +/// Issuer to authenticate. This Middleware needs to be loaded afer [`OidcAuthMiddleware`]. #[derive(Clone)] pub struct OidcLoginMiddleware where @@ -164,8 +164,11 @@ where login_session.authenticated = Some(AuthenticatedSession { id_token: id_token.clone(), access_token: token_response.access_token().clone(), - refresh_token: token_response.refresh_token().cloned(), }); + let refresh_token = token_response.refresh_token().cloned(); + if let Some(refresh_token) = refresh_token { + login_session.refresh_token = Some(refresh_token); + } session.insert(SESSION_KEY, login_session).await?; @@ -193,6 +196,7 @@ where csrf_token, pkce_verifier, authenticated: None, + refresh_token: None, }; session.insert(SESSION_KEY, oidc_session).await?; @@ -204,7 +208,7 @@ where } } -/// Layer for the [OidcAuthMiddleware]. +/// Layer for the [`OidcAuthMiddleware`]. #[derive(Clone)] pub struct OidcAuthLayer where @@ -294,7 +298,8 @@ where let session = parts .extensions .get::() - .ok_or(MiddlewareError::SessionNotFound)?; + .ok_or(MiddlewareError::SessionNotFound)? + .clone(); let mut login_session: Option> = session .get(SESSION_KEY) .await @@ -320,16 +325,16 @@ where if let Some((session, claims)) = id_token_claims { // stored id token is valid and can be used insert_extensions(&mut parts, claims.clone(), &oidcclient, session); - } else if let Some(refresh_token) = login_session - .authenticated - .as_ref() - .and_then(|x| x.refresh_token.as_ref()) - { - if let Some((claims, authenticated_session)) = + } else if let Some(refresh_token) = login_session.refresh_token.as_ref() { + if let Some((claims, authenticated_session, refresh_token)) = try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? { insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); login_session.authenticated = Some(authenticated_session); + + if let Some(refresh_token) = refresh_token { + login_session.refresh_token = Some(refresh_token); + } }; // save refreshed session or delete it when the token couldn't be refreshed @@ -350,6 +355,13 @@ where .await .map_err(|e| MiddlewareError::NextMiddleware(e.into()))? .into_response(); + + let has_logout_ext = response.extensions().get::().is_some(); + if let (true, Some(mut login_session)) = (has_logout_ext, login_session) { + login_session.authenticated = None; + session.insert(SESSION_KEY, login_session).await?; + } + Ok(response) }) } @@ -433,8 +445,14 @@ async fn try_refresh_token( client: &OidcClient, refresh_token: &RefreshToken, nonce: &Nonce, -) -> Result, AuthenticatedSession)>, MiddlewareError> -{ +) -> Result< + Option<( + IdTokenClaims, + AuthenticatedSession, + Option, + )>, + MiddlewareError, +> { let mut refresh_request = client.client.exchange_refresh_token(refresh_token); for scope in client.scopes.iter() { @@ -454,10 +472,13 @@ async fn try_refresh_token( let authenticated_session = AuthenticatedSession { id_token: id_token.clone(), access_token: token_response.access_token().clone(), - refresh_token: token_response.refresh_token().cloned(), }; - Ok(Some((claims.clone(), authenticated_session))) + Ok(Some(( + claims.clone(), + authenticated_session, + token_response.refresh_token().cloned(), + ))) } Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { // Refresh failed, refresh_token most likely expired or