fix: fixed redirect_uri with handler_uri in session

Previously the redirect_uri was the uri of the handler that needed
authentication.
Now one fixed redirect_uri for the entire application is used that will
redirect the user to the correct handler after successful
authentication.
This commit should fix: #28, #27, #26, #21
This commit is contained in:
Paul Zinselmeyer 2025-04-18 12:30:29 +02:00
parent 58369449cf
commit 19adcbabd2
Signed by: pfzetto
GPG key ID: B471A1AF06C895FD
6 changed files with 246 additions and 239 deletions

View file

@ -1,9 +1,13 @@
use axum::{ use axum::{
error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, error_handling::HandleErrorLayer,
http::Uri,
response::IntoResponse,
routing::{any, get},
Router,
}; };
use axum_oidc::{ use axum_oidc::{
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, error::MiddlewareError, handle_oidc_redirect, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims,
OidcLoginLayer, OidcRpInitiatedLogout, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout,
}; };
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -33,7 +37,7 @@ pub async fn run(
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder() let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
.with_default_http_client() .with_default_http_client()
.with_application_base_url(Uri::from_maybe_shared(app_url).expect("valid APP_URL")) .with_redirect_url(Uri::from_static("http://localhost:8080/oidc"))
.with_client_id(client_id); .with_client_id(client_id);
if let Some(client_secret) = client_secret { if let Some(client_secret) = client_secret {
oidc_client = oidc_client.with_client_secret(client_secret); oidc_client = oidc_client.with_client_secret(client_secret);
@ -56,6 +60,7 @@ pub async fn run(
.route("/logout", get(logout)) .route("/logout", get(logout))
.layer(oidc_login_service) .layer(oidc_login_service)
.route("/bar", get(maybe_authenticated)) .route("/bar", get(maybe_authenticated))
.route("/oidc", any(handle_oidc_redirect::<EmptyAdditionalClaims>))
.layer(oidc_auth_service) .layer(oidc_auth_service)
.layer(session_layer); .layer(session_layer);

View file

@ -6,23 +6,22 @@ use openidconnect::{ClientId, ClientSecret, IssuerUrl};
use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata}; use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata};
pub struct Unconfigured; pub struct Unconfigured;
pub struct ApplicationBaseUrl(Uri);
pub struct OpenidconnectClient<AC: AdditionalClaims>(crate::Client<AC>); pub struct OpenidconnectClient<AC: AdditionalClaims>(crate::Client<AC>);
pub struct HttpClient(reqwest::Client); pub struct HttpClient(reqwest::Client);
pub struct RedirectUrl(Uri);
pub struct ClientCredentials { pub struct ClientCredentials {
id: Box<str>, id: Box<str>,
secret: Option<Box<str>>, secret: Option<Box<str>>,
} }
pub struct Builder<AC: AdditionalClaims, ApplicationBaseUrl, Credentials, Client, HttpClient> { pub struct Builder<AC: AdditionalClaims, Credentials, Client, HttpClient, RedirectUrl> {
application_base_url: ApplicationBaseUrl,
credentials: Credentials, credentials: Credentials,
client: Client, client: Client,
http_client: HttpClient, http_client: HttpClient,
redirect_url: RedirectUrl,
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
scopes: Vec<Box<str>>, scopes: Vec<Box<str>>,
oidc_request_parameters: Vec<Box<str>>,
auth_context_class: Option<Box<str>>, auth_context_class: Option<Box<str>>,
_ac: PhantomData<AC>, _ac: PhantomData<AC>,
} }
@ -35,19 +34,13 @@ impl<AC: AdditionalClaims> Default for Builder<AC, (), (), (), ()> {
impl<AC: AdditionalClaims> Builder<AC, (), (), (), ()> { impl<AC: AdditionalClaims> Builder<AC, (), (), (), ()> {
/// create a new builder with default values /// create a new builder with default values
pub fn new() -> Self { pub fn new() -> Self {
let oidc_request_parameters = ["code", "state", "session_state", "iss"]
.into_iter()
.map(Box::<str>::from)
.collect();
Self { Self {
application_base_url: (),
credentials: (), credentials: (),
client: (), client: (),
http_client: (), http_client: (),
redirect_url: (),
end_session_endpoint: None, end_session_endpoint: None,
scopes: vec![Box::from("openid")], scopes: vec![Box::from("openid")],
oidc_request_parameters,
auth_context_class: None, auth_context_class: None,
_ac: PhantomData, _ac: PhantomData,
} }
@ -61,7 +54,7 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
} }
} }
impl<AC: AdditionalClaims, APPBASE, CREDS, CLIENT, HTTP> Builder<AC, APPBASE, CREDS, CLIENT, HTTP> { impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP, RURL> Builder<AC, CREDS, CLIENT, HTTP, RURL> {
/// add a scope to existing (default) scopes /// add a scope to existing (default) scopes
pub fn add_scope(mut self, scope: impl Into<Box<str>>) -> Self { pub fn add_scope(mut self, scope: impl Into<Box<str>>) -> Self {
self.scopes.push(scope.into()); self.scopes.push(scope.into());
@ -73,28 +66,6 @@ impl<AC: AdditionalClaims, APPBASE, CREDS, CLIENT, HTTP> Builder<AC, APPBASE, CR
self self
} }
/// add a query parameter that will be filtered from requests to existing (default) filtered
/// query parameters
pub fn add_oidc_request_parameter(
mut self,
oidc_request_parameter: impl Into<Box<str>>,
) -> Self {
self.oidc_request_parameters
.push(oidc_request_parameter.into());
self
}
/// replace query parameters that will be filtered from requests (including default)
pub fn with_oidc_request_parameters(
mut self,
oidc_request_parameters: impl Iterator<Item = impl Into<Box<str>>>,
) -> Self {
self.oidc_request_parameters = oidc_request_parameters
.map(|x| x.into())
.collect::<Vec<_>>();
self
}
/// authenticate with Authentication Context Class Reference /// authenticate with Authentication Context Class Reference
pub fn with_auth_context_class(mut self, acr: impl Into<Box<str>>) -> Self { pub fn with_auth_context_class(mut self, acr: impl Into<Box<str>>) -> Self {
self.auth_context_class = Some(acr.into()); self.auth_context_class = Some(acr.into());
@ -102,50 +73,29 @@ impl<AC: AdditionalClaims, APPBASE, CREDS, CLIENT, HTTP> Builder<AC, APPBASE, CR
} }
} }
impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP> Builder<AC, (), CREDS, CLIENT, HTTP> { impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RURL> {
/// set application base url (e.g. https://example.com)
pub fn with_application_base_url(
self,
url: impl Into<Uri>,
) -> Builder<AC, ApplicationBaseUrl, CREDS, CLIENT, HTTP> {
Builder {
application_base_url: ApplicationBaseUrl(url.into()),
credentials: self.credentials,
client: self.client,
http_client: self.http_client,
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: PhantomData,
}
}
}
impl<AC: AdditionalClaims, ABU, CLIENT, HTTP> Builder<AC, ABU, (), CLIENT, HTTP> {
/// set client id for authentication with issuer /// set client id for authentication with issuer
pub fn with_client_id( pub fn with_client_id(
self, self,
id: impl Into<Box<str>>, id: impl Into<Box<str>>,
) -> Builder<AC, ABU, ClientCredentials, CLIENT, HTTP> { ) -> Builder<AC, ClientCredentials, CLIENT, HTTP, RURL> {
Builder::<_, _, _, _, _> { Builder::<_, _, _, _, _> {
application_base_url: self.application_base_url,
credentials: ClientCredentials { credentials: ClientCredentials {
id: id.into(), id: id.into(),
secret: None, secret: None,
}, },
client: self.client, client: self.client,
http_client: self.http_client, http_client: self.http_client,
redirect_url: self.redirect_url,
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
_ac: PhantomData, _ac: PhantomData,
} }
} }
} }
impl<AC: AdditionalClaims, ABU, CLIENT, HTTP> Builder<AC, ABU, ClientCredentials, CLIENT, HTTP> { impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, ClientCredentials, CLIENT, HTTP, RURL> {
/// set client secret for authentication with issuer /// set client secret for authentication with issuer
pub fn with_client_secret(mut self, secret: impl Into<Box<str>>) -> Self { pub fn with_client_secret(mut self, secret: impl Into<Box<str>>) -> Self {
self.credentials.secret = Some(secret.into()); self.credentials.secret = Some(secret.into());
@ -153,47 +103,65 @@ impl<AC: AdditionalClaims, ABU, CLIENT, HTTP> Builder<AC, ABU, ClientCredentials
} }
} }
impl<AC: AdditionalClaims, ABU, CREDS, CLIENT> Builder<AC, ABU, CREDS, CLIENT, ()> { impl<AC: AdditionalClaims, CREDS, CLIENT, RURL> Builder<AC, CREDS, CLIENT, (), RURL> {
/// use custom http client /// use custom http client
pub fn with_http_client( pub fn with_http_client(
self, self,
client: reqwest::Client, client: reqwest::Client,
) -> Builder<AC, ABU, CREDS, CLIENT, HttpClient> { ) -> Builder<AC, CREDS, CLIENT, HttpClient, RURL> {
Builder { Builder {
application_base_url: self.application_base_url,
credentials: self.credentials, credentials: self.credentials,
client: self.client, client: self.client,
http_client: HttpClient(client), http_client: HttpClient(client),
redirect_url: self.redirect_url,
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
_ac: self._ac, _ac: self._ac,
} }
} }
/// use default reqwest http client /// use default reqwest http client
pub fn with_default_http_client(self) -> Builder<AC, ABU, CREDS, CLIENT, HttpClient> { pub fn with_default_http_client(self) -> Builder<AC, CREDS, CLIENT, HttpClient, RURL> {
Builder { Builder {
application_base_url: self.application_base_url,
credentials: self.credentials, credentials: self.credentials,
client: self.client, client: self.client,
http_client: HttpClient(reqwest::Client::default()), http_client: HttpClient(reqwest::Client::default()),
redirect_url: self.redirect_url,
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
_ac: self._ac, _ac: self._ac,
} }
} }
} }
impl<AC: AdditionalClaims, ABU> Builder<AC, ABU, ClientCredentials, (), HttpClient> { impl<AC: AdditionalClaims, CREDS, CLIENT, HCLIENT> Builder<AC, CREDS, CLIENT, HCLIENT, ()> {
pub fn with_redirect_url(
self,
redirect_url: Uri,
) -> Builder<AC, CREDS, CLIENT, HCLIENT, RedirectUrl> {
Builder {
credentials: self.credentials,
client: self.client,
http_client: self.http_client,
redirect_url: RedirectUrl(redirect_url),
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
auth_context_class: self.auth_context_class,
_ac: self._ac,
}
}
}
impl<AC: AdditionalClaims> Builder<AC, ClientCredentials, (), HttpClient, RedirectUrl> {
/// provide issuer details manually /// provide issuer details manually
pub fn manual( pub fn manual(
self, self,
provider_metadata: ProviderMetadata, provider_metadata: ProviderMetadata,
) -> Result<Builder<AC, ABU, ClientCredentials, OpenidconnectClient<AC>, HttpClient>, Error> ) -> Result<
{ Builder<AC, ClientCredentials, OpenidconnectClient<AC>, HttpClient, RedirectUrl>,
Error,
> {
let end_session_endpoint = provider_metadata let end_session_endpoint = provider_metadata
.additional_metadata() .additional_metadata()
.end_session_endpoint .end_session_endpoint
@ -208,16 +176,18 @@ impl<AC: AdditionalClaims, ABU> Builder<AC, ABU, ClientCredentials, (), HttpClie
.secret .secret
.as_ref() .as_ref()
.map(|x| ClientSecret::new(x.to_string())), .map(|x| ClientSecret::new(x.to_string())),
); )
.set_redirect_uri(openidconnect::RedirectUrl::new(
self.redirect_url.0.to_string(),
)?);
Ok(Builder { Ok(Builder {
application_base_url: self.application_base_url,
credentials: self.credentials, credentials: self.credentials,
client: OpenidconnectClient(client), client: OpenidconnectClient(client),
http_client: self.http_client, http_client: self.http_client,
redirect_url: self.redirect_url,
end_session_endpoint, end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
_ac: self._ac, _ac: self._ac,
}) })
@ -226,8 +196,10 @@ impl<AC: AdditionalClaims, ABU> Builder<AC, ABU, ClientCredentials, (), HttpClie
pub async fn discover( pub async fn discover(
self, self,
issuer: impl Into<Uri>, issuer: impl Into<Uri>,
) -> Result<Builder<AC, ABU, ClientCredentials, OpenidconnectClient<AC>, HttpClient>, Error> ) -> Result<
{ Builder<AC, ClientCredentials, OpenidconnectClient<AC>, HttpClient, RedirectUrl>,
Error,
> {
let issuer_url = IssuerUrl::new(issuer.into().to_string())?; let issuer_url = IssuerUrl::new(issuer.into().to_string())?;
let http_client = self.http_client.0.clone(); let http_client = self.http_client.0.clone();
let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client); let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client);
@ -237,17 +209,15 @@ impl<AC: AdditionalClaims, ABU> Builder<AC, ABU, ClientCredentials, (), HttpClie
} }
impl<AC: AdditionalClaims> impl<AC: AdditionalClaims>
Builder<AC, ApplicationBaseUrl, ClientCredentials, OpenidconnectClient<AC>, HttpClient> Builder<AC, ClientCredentials, OpenidconnectClient<AC>, HttpClient, RedirectUrl>
{ {
/// create oidc client /// create oidc client
pub fn build(self) -> OidcClient<AC> { pub fn build(self) -> OidcClient<AC> {
OidcClient { OidcClient {
scopes: self.scopes, scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
client_id: self.credentials.id, client_id: self.credentials.id,
client: self.client.0, client: self.client.0,
http_client: self.http_client.0, http_client: self.http_client.0,
application_base_url: self.application_base_url.0,
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
} }

View file

@ -72,6 +72,45 @@ pub enum MiddlewareError {
AuthMiddlewareNotFound, AuthMiddlewareNotFound,
} }
#[derive(Debug, Error)]
pub enum HandlerError {
#[error("the redirect handler got accessed without a valid session")]
RedirectedWithoutSession,
#[error("csrf token invalid")]
CsrfTokenInvalid,
#[error("id token missing")]
IdTokenMissing,
#[error("access token hash invalid")]
AccessTokenHashInvalid,
#[error("signing: {0:?}")]
Signing(#[from] openidconnect::SigningError),
#[error("signature verification: {0:?}")]
Signature(#[from] openidconnect::SignatureVerificationError),
#[error("session error: {0:?}")]
Session(#[from] tower_sessions::session::Error),
#[error("configuration: {0:?}")]
Configuration(#[from] openidconnect::ConfigurationError),
#[error("request token: {0:?}")]
RequestToken(
#[from]
openidconnect::RequestTokenError<
openidconnect::HttpClientError<openidconnect::reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>,
),
#[error("claims verification: {0:?}")]
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
#[error("url parsing: {0:?}")] #[error("url parsing: {0:?}")]
@ -93,6 +132,9 @@ pub enum Error {
#[error("extractor: {0:?}")] #[error("extractor: {0:?}")]
Middleware(#[from] MiddlewareError), Middleware(#[from] MiddlewareError),
#[error("handler: {0:?}")]
Handler(#[from] HandlerError),
} }
impl IntoResponse for ExtractorError { impl IntoResponse for ExtractorError {
@ -124,3 +166,11 @@ impl IntoResponse for MiddlewareError {
} }
} }
} }
impl IntoResponse for HandlerError {
fn into_response(self) -> axum_core::response::Response {
match self {
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
}
}

102
src/handler.rs Normal file
View file

@ -0,0 +1,102 @@
use axum::{extract::Query, response::Redirect, Extension};
use openidconnect::{
core::{CoreGenderClaim, CoreJsonWebKey},
AccessToken, AccessTokenHash, AuthorizationCode, IdTokenClaims, IdTokenVerifier,
OAuth2TokenResponse, PkceCodeVerifier, TokenResponse,
};
use serde::Deserialize;
use tower_sessions::Session;
use crate::{
error::HandlerError, AdditionalClaims, AuthenticatedSession, IdToken, OidcClient, OidcSession,
SESSION_KEY,
};
/// response data of the openid issuer after login
#[derive(Debug, Deserialize)]
pub struct OidcQuery {
code: String,
state: String,
#[allow(dead_code)]
session_state: Option<String>,
}
pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
session: Session,
Extension(oidcclient): Extension<OidcClient<AC>>,
Query(query): Query<OidcQuery>,
) -> Result<impl axum::response::IntoResponse, HandlerError> {
let mut login_session: OidcSession<AC> = session
.get(SESSION_KEY)
.await?
.ok_or(HandlerError::RedirectedWithoutSession)?;
// the request has the request headers of the oidc redirect
// parse the headers and exchange the code for a valid token
if login_session.csrf_token.secret() != &query.state {
return Err(HandlerError::CsrfTokenInvalid);
}
let token_response = oidcclient
.client
.exchange_code(AuthorizationCode::new(query.code.to_string()))?
// Set the PKCE code verifier.
.set_pkce_verifier(PkceCodeVerifier::new(
login_session.pkce_verifier.secret().to_string(),
))
.request_async(&oidcclient.http_client)
.await?;
// Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response
.id_token()
.ok_or(HandlerError::IdTokenMissing)?;
let id_token_verifier = oidcclient.client.id_token_verifier();
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
validate_access_token_hash(
id_token,
id_token_verifier,
token_response.access_token(),
claims,
)?;
login_session.authenticated = Some(AuthenticatedSession {
id_token: id_token.clone(),
access_token: token_response.access_token().clone(),
});
let refresh_token = token_response.refresh_token().cloned();
if let Some(refresh_token) = refresh_token {
login_session.refresh_token = Some(refresh_token);
}
let redirect_url = login_session.redirect_url.clone();
session.insert(SESSION_KEY, login_session).await?;
Ok(Redirect::to(&redirect_url))
}
/// Verify the access token hash to ensure that the access token hasn't been substituted for
/// another user's.
/// Returns `Ok` when access token is valid
fn validate_access_token_hash<AC: AdditionalClaims>(
id_token: &IdToken<AC>,
id_token_verifier: IdTokenVerifier<CoreJsonWebKey>,
access_token: &AccessToken,
claims: &IdTokenClaims<AC, CoreGenderClaim>,
) -> Result<(), HandlerError> {
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
access_token,
id_token.signing_alg()?,
id_token.signing_key(&id_token_verifier)?,
)?;
if actual_access_token_hash == *expected_access_token_hash {
Ok(())
} else {
Err(HandlerError::AccessTokenHashInvalid)
}
} else {
Ok(())
}
}

View file

@ -21,9 +21,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
pub mod builder; pub mod builder;
pub mod error; pub mod error;
mod extractor; mod extractor;
mod handler;
mod middleware; mod middleware;
pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout};
pub use handler::handle_oidc_redirect;
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
const SESSION_KEY: &str = "axum-oidc"; const SESSION_KEY: &str = "axum-oidc";
@ -100,11 +102,9 @@ pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[derive(Clone)] #[derive(Clone)]
pub struct OidcClient<AC: AdditionalClaims> { pub struct OidcClient<AC: AdditionalClaims> {
scopes: Vec<Box<str>>, scopes: Vec<Box<str>>,
oidc_request_parameters: Vec<Box<str>>,
client_id: Box<str>, client_id: Box<str>,
client: Client<AC>, client: Client<AC>,
http_client: reqwest::Client, http_client: reqwest::Client,
application_base_url: Uri,
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
auth_context_class: Option<Box<str>>, auth_context_class: Option<Box<str>>,
} }
@ -115,15 +115,6 @@ pub struct EmptyAdditionalClaims {}
impl AdditionalClaims for EmptyAdditionalClaims {} impl AdditionalClaims for EmptyAdditionalClaims {}
impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {} impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
/// response data of the openid issuer after login
#[derive(Debug, Deserialize)]
struct OidcQuery {
code: String,
state: String,
#[allow(dead_code)]
session_state: Option<String>,
}
/// oidc session /// oidc session
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
#[serde(bound = "AC: Serialize + DeserializeOwned")] #[serde(bound = "AC: Serialize + DeserializeOwned")]
@ -133,6 +124,7 @@ struct OidcSession<AC: AdditionalClaims> {
pkce_verifier: PkceCodeVerifier, pkce_verifier: PkceCodeVerifier,
authenticated: Option<AuthenticatedSession<AC>>, authenticated: Option<AuthenticatedSession<AC>>,
refresh_token: Option<RefreshToken>, refresh_token: Option<RefreshToken>,
redirect_url: Box<str>,
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]

View file

@ -3,22 +3,18 @@ use std::{
task::{Context, Poll}, task::{Context, Poll},
}; };
use axum::{ use axum::response::{IntoResponse, Redirect};
extract::Query, use axum_core::response::Response;
response::{IntoResponse, Redirect},
};
use axum_core::{extract::FromRequestParts, response::Response};
use futures_util::future::BoxFuture; use futures_util::future::BoxFuture;
use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use http::{request::Parts, Request};
use tower_layer::Layer; use tower_layer::Layer;
use tower_service::Service; use tower_service::Service;
use tower_sessions::Session; use tower_sessions::Session;
use openidconnect::{ use openidconnect::{
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
AccessToken, AccessTokenHash, AuthenticationContextClass, AuthorizationCode, CsrfToken, AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
IdTokenClaims, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
PkceCodeVerifier, RedirectUrl, RefreshToken,
RequestTokenError::ServerResponse, RequestTokenError::ServerResponse,
Scope, TokenResponse, Scope, TokenResponse,
}; };
@ -27,7 +23,7 @@ use crate::{
error::MiddlewareError, error::MiddlewareError,
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
OidcQuery, OidcSession, SESSION_KEY, OidcSession, SESSION_KEY,
}; };
/// Layer for the [`OidcLoginMiddleware`]. /// Layer for the [`OidcLoginMiddleware`].
@ -106,117 +102,53 @@ where
} else { } else {
// no valid id token or refresh token was found and the user has to login // no valid id token or refresh token was found and the user has to login
Box::pin(async move { Box::pin(async move {
let (mut parts, _) = request.into_parts(); let (parts, _) = request.into_parts();
let mut oidcclient: OidcClient<AC> = parts let oidcclient: OidcClient<AC> = parts
.extensions .extensions
.get() .get()
.cloned() .cloned()
.ok_or(MiddlewareError::AuthMiddlewareNotFound)?; .ok_or(MiddlewareError::AuthMiddlewareNotFound)?;
let query = Query::<OidcQuery>::from_request_parts(&mut parts, &())
.await
.ok();
let session = parts let session = parts
.extensions .extensions
.get::<Session>() .get::<Session>()
.ok_or(MiddlewareError::SessionNotFound)?; .ok_or(MiddlewareError::SessionNotFound)?;
let login_session: Option<OidcSession<AC>> = session
.get(SESSION_KEY)
.await
.map_err(MiddlewareError::from)?;
let handler_uri = strip_oidc_from_path( // generate a login url and redirect the user to it
oidcclient.application_base_url.clone(),
&parts.uri,
&oidcclient.oidc_request_parameters,
)?;
oidcclient.client = oidcclient let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
.client let (auth_url, csrf_token, nonce) = {
.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); let mut auth = oidcclient.client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
if let (Some(mut login_session), Some(query)) = (login_session, query) { for scope in oidcclient.scopes.iter() {
// the request has the request headers of the oidc redirect auth = auth.add_scope(Scope::new(scope.to_string()));
// parse the headers and exchange the code for a valid token
if login_session.csrf_token.secret() != &query.state {
return Err(MiddlewareError::CsrfTokenInvalid);
} }
let token_response = oidcclient if let Some(acr) = oidcclient.auth_context_class {
.client auth = auth
.exchange_code(AuthorizationCode::new(query.code.to_string()))? .add_auth_context_value(AuthenticationContextClass::new(acr.into()));
// Set the PKCE code verifier.
.set_pkce_verifier(PkceCodeVerifier::new(
login_session.pkce_verifier.secret().to_string(),
))
.request_async(&oidcclient.http_client)
.await?;
// Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response
.id_token()
.ok_or(MiddlewareError::IdTokenMissing)?;
let id_token_verifier = oidcclient.client.id_token_verifier();
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
validate_access_token_hash(
id_token,
id_token_verifier,
token_response.access_token(),
claims,
)?;
login_session.authenticated = Some(AuthenticatedSession {
id_token: id_token.clone(),
access_token: token_response.access_token().clone(),
});
let refresh_token = token_response.refresh_token().cloned();
if let Some(refresh_token) = refresh_token {
login_session.refresh_token = Some(refresh_token);
} }
session.insert(SESSION_KEY, login_session).await?; auth.set_pkce_challenge(pkce_challenge).url()
};
Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) let oidc_session = OidcSession::<AC> {
} else { nonce,
// generate a login url and redirect the user to it csrf_token,
pkce_verifier,
authenticated: None,
refresh_token: None,
redirect_url: parts.uri.to_string().into(),
};
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); session.insert(SESSION_KEY, oidc_session).await?;
let (auth_url, csrf_token, nonce) = {
let mut auth = oidcclient.client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in oidcclient.scopes.iter() { Ok(Redirect::to(auth_url.as_str()).into_response())
auth = auth.add_scope(Scope::new(scope.to_string()));
}
if let Some(acr) = oidcclient.auth_context_class {
auth = auth.add_auth_context_value(AuthenticationContextClass::new(
acr.into(),
));
}
auth.set_pkce_challenge(pkce_challenge).url()
};
let oidc_session = OidcSession::<AC> {
nonce,
csrf_token,
pkce_verifier,
authenticated: None,
refresh_token: None,
};
session.insert(SESSION_KEY, oidc_session).await?;
Ok(Redirect::temporary(auth_url.as_str()).into_response())
}
}) })
} }
} }
@ -291,7 +223,7 @@ where
fn call(&mut self, request: Request<B>) -> Self::Future { fn call(&mut self, request: Request<B>) -> Self::Future {
let inner = self.inner.clone(); let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner); let mut inner = std::mem::replace(&mut self.inner, inner);
let mut oidcclient = self.client.clone(); let oidcclient = self.client.clone();
Box::pin(async move { Box::pin(async move {
let (mut parts, body) = request.into_parts(); let (mut parts, body) = request.into_parts();
@ -305,16 +237,6 @@ where
.await .await
.map_err(MiddlewareError::from)?; .map_err(MiddlewareError::from)?;
let handler_uri = strip_oidc_from_path(
oidcclient.application_base_url.clone(),
&parts.uri,
&oidcclient.oidc_request_parameters,
)?;
oidcclient.client = oidcclient
.client
.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?);
if let Some(login_session) = &mut login_session { if let Some(login_session) = &mut login_session {
let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { let id_token_claims = login_session.authenticated.as_ref().and_then(|session| {
session session
@ -329,6 +251,7 @@ where
// stored id token is valid and can be used // stored id token is valid and can be used
insert_extensions(&mut parts, claims.clone(), &oidcclient, session); insert_extensions(&mut parts, claims.clone(), &oidcclient, session);
} else if let Some(refresh_token) = login_session.refresh_token.as_ref() { } else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
// session is expired but can be refreshed using the refresh_token
if let Some((claims, authenticated_session, refresh_token)) = if let Some((claims, authenticated_session, refresh_token)) =
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
{ {
@ -370,41 +293,6 @@ where
} }
} }
/// Helper function to remove the OpenID Connect authentication response query attributes from a
/// [`Uri`].
pub fn strip_oidc_from_path(
base_url: Uri,
uri: &Uri,
filter: &[Box<str>],
) -> Result<Uri, MiddlewareError> {
let mut base_url = base_url.into_parts();
base_url.path_and_query = uri
.path_and_query()
.map(|path_and_query| {
let query = path_and_query
.query()
.map(|uri| {
uri.split('&')
.filter(|x| filter.iter().all(|y| !x.starts_with(y.as_ref())))
.fold(String::default(), |mut acc, x| {
if !acc.is_empty() {
acc += "&";
} else {
acc += "?";
}
acc += x;
acc
})
})
.unwrap_or_default();
PathAndQuery::from_maybe_shared(format!("{}{}", path_and_query.path(), query))
})
.transpose()?;
Ok(Uri::from_parts(base_url)?)
}
/// insert all extensions that are used by the extractors /// insert all extensions that are used by the extractors
fn insert_extensions<AC: AdditionalClaims>( fn insert_extensions<AC: AdditionalClaims>(
parts: &mut Parts, parts: &mut Parts,