implement fix for #10

fixed #10 by implementing a flag in the response extensions that
instructs the middleware to clear the session. The flag is automatically
set when using the `OidcRpInitiatedLogout` as a responder.

improved documentation

modified example to reflect api changes
This commit is contained in:
Paul Zinselmeyer 2024-04-20 20:35:04 +02:00
parent a7b76ace76
commit ac3e0caa0b
Signed by: pfzetto
GPG key ID: 142847B253911DB0
5 changed files with 76 additions and 34 deletions

View file

@ -1,9 +1,5 @@
use axum::{ use axum::{
error_handling::HandleErrorLayer, error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router,
http::Uri,
response::{IntoResponse, Redirect},
routing::get,
Router,
}; };
use axum_oidc::{ use axum_oidc::{
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer,
@ -84,9 +80,5 @@ async fn maybe_authenticated(
} }
async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse {
let logout_uri = logout logout.with_post_logout_redirect(Uri::from_static("https://pfzetto.de"))
.with_post_logout_redirect(Uri::from_static("https://pfzetto.de"))
.uri()
.unwrap();
Redirect::temporary(&logout_uri.to_string())
} }

View file

@ -13,6 +13,9 @@ pub enum ExtractorError {
#[error("rp initiated logout information not found")] #[error("rp initiated logout information not found")]
RpInitiatedLogoutInformationNotFound, RpInitiatedLogoutInformationNotFound,
#[error("could not build rp initiated logout uri")]
FailedToCreateRpInitiatedLogoutUri,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -88,6 +91,9 @@ impl IntoResponse for ExtractorError {
Self::RpInitiatedLogoutInformationNotFound => { Self::RpInitiatedLogoutInformationNotFound => {
(StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response()
} }
Self::FailedToCreateRpInitiatedLogoutUri => {
(StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response()
}
} }
} }
} }

View file

@ -1,14 +1,15 @@
use std::{borrow::Cow, ops::Deref}; use std::{borrow::Cow, ops::Deref};
use crate::{error::ExtractorError, AdditionalClaims}; use crate::{error::ExtractorError, AdditionalClaims, ClearSessionFlag};
use async_trait::async_trait; use async_trait::async_trait;
use axum_core::extract::FromRequestParts; use axum::response::Redirect;
use axum_core::{extract::FromRequestParts, response::IntoResponse};
use http::{request::Parts, uri::PathAndQuery, Uri}; use http::{request::Parts, uri::PathAndQuery, Uri};
use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims};
/// Extractor for the OpenID Connect Claims. /// Extractor for the OpenID Connect Claims.
/// ///
/// This Extractor will only return the Claims when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. /// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
#[derive(Clone)] #[derive(Clone)]
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>); pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
@ -48,7 +49,7 @@ where
/// Extractor for the OpenID Connect Access Token. /// Extractor for the OpenID Connect Access Token.
/// ///
/// This Extractor will only return the Access Token when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. /// This Extractor will only return the Access Token when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
#[derive(Clone)] #[derive(Clone)]
pub struct OidcAccessToken(pub String); pub struct OidcAccessToken(pub String);
@ -84,7 +85,7 @@ impl AsRef<str> for OidcAccessToken {
/// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL /// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL
/// ///
/// This Extractor will only succed when the cached session is valid, [crate::middleware::OidcAuthMiddleware] is loaded and the issuer supports RP-Initialized Logout. /// This Extractor will only succed when the cached session is valid, [`crate::middleware::OidcAuthMiddleware`] is loaded and the issuer supports RP-Initialized Logout.
#[derive(Clone)] #[derive(Clone)]
pub struct OidcRpInitiatedLogout { pub struct OidcRpInitiatedLogout {
pub(crate) end_session_endpoint: Uri, pub(crate) end_session_endpoint: Uri,
@ -106,7 +107,9 @@ impl OidcRpInitiatedLogout {
self.state = Some(state); self.state = Some(state);
self self
} }
/// get the uri that the client needs to access for logout /// get the uri that the client needs to access for logout. This does **NOT** delete the
/// session in axum-oidc. You should use the [`ClearSessionFlag`] responder or include
/// [`OidcRpInitiatedLogout`] in the response extensions
pub fn uri(&self) -> Result<Uri, http::Error> { pub fn uri(&self) -> Result<Uri, http::Error> {
let mut parts = self.end_session_endpoint.clone().into_parts(); let mut parts = self.end_session_endpoint.clone().into_parts();
@ -159,3 +162,17 @@ where
.ok_or(ExtractorError::Unauthorized) .ok_or(ExtractorError::Unauthorized)
} }
} }
impl IntoResponse for OidcRpInitiatedLogout {
/// redirect to the logout uri and signal the [`crate::middleware::OidcAuthMiddleware`] that
/// the session should be cleared
fn into_response(self) -> axum_core::response::Response {
if let Ok(uri) = self.uri() {
let mut response = Redirect::temporary(&uri.to_string()).into_response();
response.extensions_mut().insert(ClearSessionFlag);
response
} else {
ExtractorError::FailedToCreateRpInitiatedLogoutUri.into_response()
}
}
}

View file

@ -103,6 +103,8 @@ pub struct OidcClient<AC: AdditionalClaims> {
} }
impl<AC: AdditionalClaims> OidcClient<AC> { impl<AC: AdditionalClaims> OidcClient<AC> {
/// create a new [`OidcClient`] by fetching the required information from the
/// `/.well-known/openid-configuration` endpoint of the issuer.
pub async fn discover_new( pub async fn discover_new(
application_base_url: Uri, application_base_url: Uri,
issuer: String, issuer: String,
@ -157,6 +159,7 @@ struct OidcSession<AC: AdditionalClaims> {
csrf_token: CsrfToken, csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier, pkce_verifier: PkceCodeVerifier,
authenticated: Option<AuthenticatedSession<AC>>, authenticated: Option<AuthenticatedSession<AC>>,
refresh_token: Option<RefreshToken>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
@ -164,7 +167,6 @@ struct OidcSession<AC: AdditionalClaims> {
struct AuthenticatedSession<AC: AdditionalClaims> { struct AuthenticatedSession<AC: AdditionalClaims> {
id_token: IdToken<AC>, id_token: IdToken<AC>,
access_token: AccessToken, access_token: AccessToken,
refresh_token: Option<RefreshToken>,
} }
/// additional metadata that is discovered on client creation via the /// additional metadata that is discovered on client creation via the
@ -174,3 +176,7 @@ struct AdditionalProviderMetadata {
end_session_endpoint: Option<String>, end_session_endpoint: Option<String>,
} }
impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {} impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {}
/// response extension flag to signal the [`OidcAuthLayer`] that the session should be cleared.
#[derive(Clone, Copy)]
pub struct ClearSessionFlag;

View file

@ -26,11 +26,11 @@ use openidconnect::{
use crate::{ use crate::{
error::{Error, MiddlewareError}, error::{Error, MiddlewareError},
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
AdditionalClaims, AuthenticatedSession, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
SESSION_KEY, OidcQuery, OidcSession, SESSION_KEY,
}; };
/// Layer for the [OidcLoginMiddleware]. /// Layer for the [`OidcLoginMiddleware`].
#[derive(Clone, Default)] #[derive(Clone, Default)]
pub struct OidcLoginLayer<AC> pub struct OidcLoginLayer<AC>
where where
@ -62,7 +62,7 @@ where
} }
/// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect /// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect
/// Issuer to authenticate. This Middleware needs to be loaded afer [OidcAuthMiddleware]. /// Issuer to authenticate. This Middleware needs to be loaded afer [`OidcAuthMiddleware`].
#[derive(Clone)] #[derive(Clone)]
pub struct OidcLoginMiddleware<I, AC> pub struct OidcLoginMiddleware<I, AC>
where where
@ -164,8 +164,11 @@ where
login_session.authenticated = Some(AuthenticatedSession { login_session.authenticated = Some(AuthenticatedSession {
id_token: id_token.clone(), id_token: id_token.clone(),
access_token: token_response.access_token().clone(), access_token: token_response.access_token().clone(),
refresh_token: token_response.refresh_token().cloned(),
}); });
let refresh_token = token_response.refresh_token().cloned();
if let Some(refresh_token) = refresh_token {
login_session.refresh_token = Some(refresh_token);
}
session.insert(SESSION_KEY, login_session).await?; session.insert(SESSION_KEY, login_session).await?;
@ -193,6 +196,7 @@ where
csrf_token, csrf_token,
pkce_verifier, pkce_verifier,
authenticated: None, authenticated: None,
refresh_token: None,
}; };
session.insert(SESSION_KEY, oidc_session).await?; session.insert(SESSION_KEY, oidc_session).await?;
@ -204,7 +208,7 @@ where
} }
} }
/// Layer for the [OidcAuthMiddleware]. /// Layer for the [`OidcAuthMiddleware`].
#[derive(Clone)] #[derive(Clone)]
pub struct OidcAuthLayer<AC> pub struct OidcAuthLayer<AC>
where where
@ -294,7 +298,8 @@ where
let session = parts let session = parts
.extensions .extensions
.get::<Session>() .get::<Session>()
.ok_or(MiddlewareError::SessionNotFound)?; .ok_or(MiddlewareError::SessionNotFound)?
.clone();
let mut login_session: Option<OidcSession<AC>> = session let mut login_session: Option<OidcSession<AC>> = session
.get(SESSION_KEY) .get(SESSION_KEY)
.await .await
@ -320,16 +325,16 @@ where
if let Some((session, claims)) = id_token_claims { if let Some((session, claims)) = id_token_claims {
// stored id token is valid and can be used // stored id token is valid and can be used
insert_extensions(&mut parts, claims.clone(), &oidcclient, session); insert_extensions(&mut parts, claims.clone(), &oidcclient, session);
} else if let Some(refresh_token) = login_session } else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
.authenticated if let Some((claims, authenticated_session, refresh_token)) =
.as_ref()
.and_then(|x| x.refresh_token.as_ref())
{
if let Some((claims, authenticated_session)) =
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
{ {
insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session);
login_session.authenticated = Some(authenticated_session); login_session.authenticated = Some(authenticated_session);
if let Some(refresh_token) = refresh_token {
login_session.refresh_token = Some(refresh_token);
}
}; };
// save refreshed session or delete it when the token couldn't be refreshed // save refreshed session or delete it when the token couldn't be refreshed
@ -350,6 +355,13 @@ where
.await .await
.map_err(|e| MiddlewareError::NextMiddleware(e.into()))? .map_err(|e| MiddlewareError::NextMiddleware(e.into()))?
.into_response(); .into_response();
let has_logout_ext = response.extensions().get::<ClearSessionFlag>().is_some();
if let (true, Some(mut login_session)) = (has_logout_ext, login_session) {
login_session.authenticated = None;
session.insert(SESSION_KEY, login_session).await?;
}
Ok(response) Ok(response)
}) })
} }
@ -433,8 +445,14 @@ async fn try_refresh_token<AC: AdditionalClaims>(
client: &OidcClient<AC>, client: &OidcClient<AC>,
refresh_token: &RefreshToken, refresh_token: &RefreshToken,
nonce: &Nonce, nonce: &Nonce,
) -> Result<Option<(IdTokenClaims<AC, CoreGenderClaim>, AuthenticatedSession<AC>)>, MiddlewareError> ) -> Result<
{ Option<(
IdTokenClaims<AC, CoreGenderClaim>,
AuthenticatedSession<AC>,
Option<RefreshToken>,
)>,
MiddlewareError,
> {
let mut refresh_request = client.client.exchange_refresh_token(refresh_token); let mut refresh_request = client.client.exchange_refresh_token(refresh_token);
for scope in client.scopes.iter() { for scope in client.scopes.iter() {
@ -454,10 +472,13 @@ async fn try_refresh_token<AC: AdditionalClaims>(
let authenticated_session = AuthenticatedSession { let authenticated_session = AuthenticatedSession {
id_token: id_token.clone(), id_token: id_token.clone(),
access_token: token_response.access_token().clone(), access_token: token_response.access_token().clone(),
refresh_token: token_response.refresh_token().cloned(),
}; };
Ok(Some((claims.clone(), authenticated_session))) Ok(Some((
claims.clone(),
authenticated_session,
token_response.refresh_token().cloned(),
)))
} }
Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => {
// Refresh failed, refresh_token most likely expired or // Refresh failed, refresh_token most likely expired or