diff --git a/examples/basic/src/lib.rs b/examples/basic/src/lib.rs index be7daca..4c3a7e6 100644 --- a/examples/basic/src/lib.rs +++ b/examples/basic/src/lib.rs @@ -1,9 +1,13 @@ use axum::{ - error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, + error_handling::HandleErrorLayer, + http::Uri, + response::IntoResponse, + routing::{any, get}, + Router, }; use axum_oidc::{ - error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, - OidcLoginLayer, OidcRpInitiatedLogout, + error::MiddlewareError, handle_oidc_redirect, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, + OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -33,7 +37,7 @@ pub async fn run( let mut oidc_client = OidcClient::::builder() .with_default_http_client() - .with_application_base_url(Uri::from_maybe_shared(app_url).expect("valid APP_URL")) + .with_redirect_url(Uri::from_static("http://localhost:8080/oidc")) .with_client_id(client_id); if let Some(client_secret) = client_secret { oidc_client = oidc_client.with_client_secret(client_secret); @@ -56,6 +60,7 @@ pub async fn run( .route("/logout", get(logout)) .layer(oidc_login_service) .route("/bar", get(maybe_authenticated)) + .route("/oidc", any(handle_oidc_redirect::)) .layer(oidc_auth_service) .layer(session_layer); diff --git a/src/builder.rs b/src/builder.rs index ba3b351..4e75080 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -6,23 +6,22 @@ use openidconnect::{ClientId, ClientSecret, IssuerUrl}; use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata}; pub struct Unconfigured; -pub struct ApplicationBaseUrl(Uri); pub struct OpenidconnectClient(crate::Client); pub struct HttpClient(reqwest::Client); +pub struct RedirectUrl(Uri); pub struct ClientCredentials { id: Box, secret: Option>, } -pub struct Builder { - application_base_url: ApplicationBaseUrl, +pub struct Builder { credentials: Credentials, client: Client, http_client: HttpClient, + redirect_url: RedirectUrl, end_session_endpoint: Option, scopes: Vec>, - oidc_request_parameters: Vec>, auth_context_class: Option>, _ac: PhantomData, } @@ -35,19 +34,13 @@ impl Default for Builder { impl Builder { /// create a new builder with default values pub fn new() -> Self { - let oidc_request_parameters = ["code", "state", "session_state", "iss"] - .into_iter() - .map(Box::::from) - .collect(); - Self { - application_base_url: (), credentials: (), client: (), http_client: (), + redirect_url: (), end_session_endpoint: None, scopes: vec![Box::from("openid")], - oidc_request_parameters, auth_context_class: None, _ac: PhantomData, } @@ -61,7 +54,7 @@ impl OidcClient { } } -impl Builder { +impl Builder { /// add a scope to existing (default) scopes pub fn add_scope(mut self, scope: impl Into>) -> Self { self.scopes.push(scope.into()); @@ -73,28 +66,6 @@ impl Builder>, - ) -> Self { - self.oidc_request_parameters - .push(oidc_request_parameter.into()); - self - } - - /// replace query parameters that will be filtered from requests (including default) - pub fn with_oidc_request_parameters( - mut self, - oidc_request_parameters: impl Iterator>>, - ) -> Self { - self.oidc_request_parameters = oidc_request_parameters - .map(|x| x.into()) - .collect::>(); - self - } - /// authenticate with Authentication Context Class Reference pub fn with_auth_context_class(mut self, acr: impl Into>) -> Self { self.auth_context_class = Some(acr.into()); @@ -102,50 +73,29 @@ impl Builder Builder { - /// set application base url (e.g. https://example.com) - pub fn with_application_base_url( - self, - url: impl Into, - ) -> Builder { - Builder { - application_base_url: ApplicationBaseUrl(url.into()), - credentials: self.credentials, - client: self.client, - http_client: self.http_client, - end_session_endpoint: self.end_session_endpoint, - scopes: self.scopes, - oidc_request_parameters: self.oidc_request_parameters, - auth_context_class: self.auth_context_class, - _ac: PhantomData, - } - } -} - -impl Builder { +impl Builder { /// set client id for authentication with issuer pub fn with_client_id( self, id: impl Into>, - ) -> Builder { + ) -> Builder { Builder::<_, _, _, _, _> { - application_base_url: self.application_base_url, credentials: ClientCredentials { id: id.into(), secret: None, }, client: self.client, http_client: self.http_client, + redirect_url: self.redirect_url, end_session_endpoint: self.end_session_endpoint, scopes: self.scopes, - oidc_request_parameters: self.oidc_request_parameters, auth_context_class: self.auth_context_class, _ac: PhantomData, } } } -impl Builder { +impl Builder { /// set client secret for authentication with issuer pub fn with_client_secret(mut self, secret: impl Into>) -> Self { self.credentials.secret = Some(secret.into()); @@ -153,47 +103,65 @@ impl Builder Builder { +impl Builder { /// use custom http client pub fn with_http_client( self, client: reqwest::Client, - ) -> Builder { + ) -> Builder { Builder { - application_base_url: self.application_base_url, credentials: self.credentials, client: self.client, http_client: HttpClient(client), + redirect_url: self.redirect_url, end_session_endpoint: self.end_session_endpoint, scopes: self.scopes, - oidc_request_parameters: self.oidc_request_parameters, auth_context_class: self.auth_context_class, _ac: self._ac, } } /// use default reqwest http client - pub fn with_default_http_client(self) -> Builder { + pub fn with_default_http_client(self) -> Builder { Builder { - application_base_url: self.application_base_url, credentials: self.credentials, client: self.client, http_client: HttpClient(reqwest::Client::default()), + redirect_url: self.redirect_url, end_session_endpoint: self.end_session_endpoint, scopes: self.scopes, - oidc_request_parameters: self.oidc_request_parameters, auth_context_class: self.auth_context_class, _ac: self._ac, } } } -impl Builder { +impl Builder { + pub fn with_redirect_url( + self, + redirect_url: Uri, + ) -> Builder { + Builder { + credentials: self.credentials, + client: self.client, + http_client: self.http_client, + redirect_url: RedirectUrl(redirect_url), + end_session_endpoint: self.end_session_endpoint, + scopes: self.scopes, + auth_context_class: self.auth_context_class, + _ac: self._ac, + } + } +} + +impl Builder { /// provide issuer details manually pub fn manual( self, provider_metadata: ProviderMetadata, - ) -> Result, HttpClient>, Error> - { + ) -> Result< + Builder, HttpClient, RedirectUrl>, + Error, + > { let end_session_endpoint = provider_metadata .additional_metadata() .end_session_endpoint @@ -208,16 +176,18 @@ impl Builder Builder, - ) -> Result, HttpClient>, Error> - { + ) -> Result< + Builder, HttpClient, RedirectUrl>, + Error, + > { let issuer_url = IssuerUrl::new(issuer.into().to_string())?; let http_client = self.http_client.0.clone(); let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client); @@ -237,17 +209,15 @@ impl Builder - Builder, HttpClient> + Builder, HttpClient, RedirectUrl> { /// create oidc client pub fn build(self) -> OidcClient { OidcClient { scopes: self.scopes, - oidc_request_parameters: self.oidc_request_parameters, client_id: self.credentials.id, client: self.client.0, http_client: self.http_client.0, - application_base_url: self.application_base_url.0, end_session_endpoint: self.end_session_endpoint, auth_context_class: self.auth_context_class, } diff --git a/src/error.rs b/src/error.rs index 0bd1cdc..1bd47d5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -72,6 +72,45 @@ pub enum MiddlewareError { AuthMiddlewareNotFound, } +#[derive(Debug, Error)] +pub enum HandlerError { + #[error("the redirect handler got accessed without a valid session")] + RedirectedWithoutSession, + + #[error("csrf token invalid")] + CsrfTokenInvalid, + + #[error("id token missing")] + IdTokenMissing, + + #[error("access token hash invalid")] + AccessTokenHashInvalid, + + #[error("signing: {0:?}")] + Signing(#[from] openidconnect::SigningError), + + #[error("signature verification: {0:?}")] + Signature(#[from] openidconnect::SignatureVerificationError), + + #[error("session error: {0:?}")] + Session(#[from] tower_sessions::session::Error), + + #[error("configuration: {0:?}")] + Configuration(#[from] openidconnect::ConfigurationError), + + #[error("request token: {0:?}")] + RequestToken( + #[from] + openidconnect::RequestTokenError< + openidconnect::HttpClientError, + StandardErrorResponse, + >, + ), + + #[error("claims verification: {0:?}")] + ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), +} + #[derive(Debug, Error)] pub enum Error { #[error("url parsing: {0:?}")] @@ -93,6 +132,9 @@ pub enum Error { #[error("extractor: {0:?}")] Middleware(#[from] MiddlewareError), + + #[error("handler: {0:?}")] + Handler(#[from] HandlerError), } impl IntoResponse for ExtractorError { @@ -124,3 +166,11 @@ impl IntoResponse for MiddlewareError { } } } + +impl IntoResponse for HandlerError { + fn into_response(self) -> axum_core::response::Response { + match self { + _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), + } + } +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..c3bbd95 --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,102 @@ +use axum::{extract::Query, response::Redirect, Extension}; +use openidconnect::{ + core::{CoreGenderClaim, CoreJsonWebKey}, + AccessToken, AccessTokenHash, AuthorizationCode, IdTokenClaims, IdTokenVerifier, + OAuth2TokenResponse, PkceCodeVerifier, TokenResponse, +}; +use serde::Deserialize; +use tower_sessions::Session; + +use crate::{ + error::HandlerError, AdditionalClaims, AuthenticatedSession, IdToken, OidcClient, OidcSession, + SESSION_KEY, +}; + +/// response data of the openid issuer after login +#[derive(Debug, Deserialize)] +pub struct OidcQuery { + code: String, + state: String, + #[allow(dead_code)] + session_state: Option, +} + +pub async fn handle_oidc_redirect( + session: Session, + Extension(oidcclient): Extension>, + Query(query): Query, +) -> Result { + let mut login_session: OidcSession = session + .get(SESSION_KEY) + .await? + .ok_or(HandlerError::RedirectedWithoutSession)?; + // the request has the request headers of the oidc redirect + // parse the headers and exchange the code for a valid token + + if login_session.csrf_token.secret() != &query.state { + return Err(HandlerError::CsrfTokenInvalid); + } + + let token_response = oidcclient + .client + .exchange_code(AuthorizationCode::new(query.code.to_string()))? + // Set the PKCE code verifier. + .set_pkce_verifier(PkceCodeVerifier::new( + login_session.pkce_verifier.secret().to_string(), + )) + .request_async(&oidcclient.http_client) + .await?; + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(HandlerError::IdTokenMissing)?; + let id_token_verifier = oidcclient.client.id_token_verifier(); + let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; + + validate_access_token_hash( + id_token, + id_token_verifier, + token_response.access_token(), + claims, + )?; + + login_session.authenticated = Some(AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + }); + let refresh_token = token_response.refresh_token().cloned(); + if let Some(refresh_token) = refresh_token { + login_session.refresh_token = Some(refresh_token); + } + + let redirect_url = login_session.redirect_url.clone(); + session.insert(SESSION_KEY, login_session).await?; + + Ok(Redirect::to(&redirect_url)) +} + +/// Verify the access token hash to ensure that the access token hasn't been substituted for +/// another user's. +/// Returns `Ok` when access token is valid +fn validate_access_token_hash( + id_token: &IdToken, + id_token_verifier: IdTokenVerifier, + access_token: &AccessToken, + claims: &IdTokenClaims, +) -> Result<(), HandlerError> { + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + access_token, + id_token.signing_alg()?, + id_token.signing_key(&id_token_verifier)?, + )?; + if actual_access_token_hash == *expected_access_token_hash { + Ok(()) + } else { + Err(HandlerError::AccessTokenHashInvalid) + } + } else { + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 3319875..dc22366 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,9 +21,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub mod builder; pub mod error; mod extractor; +mod handler; mod middleware; pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; +pub use handler::handle_oidc_redirect; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; const SESSION_KEY: &str = "axum-oidc"; @@ -100,11 +102,9 @@ pub type BoxError = Box; #[derive(Clone)] pub struct OidcClient { scopes: Vec>, - oidc_request_parameters: Vec>, client_id: Box, client: Client, http_client: reqwest::Client, - application_base_url: Uri, end_session_endpoint: Option, auth_context_class: Option>, } @@ -115,15 +115,6 @@ pub struct EmptyAdditionalClaims {} impl AdditionalClaims for EmptyAdditionalClaims {} impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {} -/// response data of the openid issuer after login -#[derive(Debug, Deserialize)] -struct OidcQuery { - code: String, - state: String, - #[allow(dead_code)] - session_state: Option, -} - /// oidc session #[derive(Serialize, Deserialize, Debug)] #[serde(bound = "AC: Serialize + DeserializeOwned")] @@ -133,6 +124,7 @@ struct OidcSession { pkce_verifier: PkceCodeVerifier, authenticated: Option>, refresh_token: Option, + redirect_url: Box, } #[derive(Serialize, Deserialize, Debug)] diff --git a/src/middleware.rs b/src/middleware.rs index f68e923..5eb14e6 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -3,22 +3,18 @@ use std::{ task::{Context, Poll}, }; -use axum::{ - extract::Query, - response::{IntoResponse, Redirect}, -}; -use axum_core::{extract::FromRequestParts, response::Response}; +use axum::response::{IntoResponse, Redirect}; +use axum_core::response::Response; use futures_util::future::BoxFuture; -use http::{request::Parts, uri::PathAndQuery, Request, Uri}; +use http::{request::Parts, Request}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, - AccessToken, AccessTokenHash, AuthenticationContextClass, AuthorizationCode, CsrfToken, - IdTokenClaims, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, - PkceCodeVerifier, RedirectUrl, RefreshToken, + AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims, + IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -27,7 +23,7 @@ use crate::{ error::MiddlewareError, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, - OidcQuery, OidcSession, SESSION_KEY, + OidcSession, SESSION_KEY, }; /// Layer for the [`OidcLoginMiddleware`]. @@ -106,117 +102,53 @@ where } else { // no valid id token or refresh token was found and the user has to login Box::pin(async move { - let (mut parts, _) = request.into_parts(); + let (parts, _) = request.into_parts(); - let mut oidcclient: OidcClient = parts + let oidcclient: OidcClient = parts .extensions .get() .cloned() .ok_or(MiddlewareError::AuthMiddlewareNotFound)?; - let query = Query::::from_request_parts(&mut parts, &()) - .await - .ok(); - let session = parts .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let login_session: Option> = session - .get(SESSION_KEY) - .await - .map_err(MiddlewareError::from)?; - let handler_uri = strip_oidc_from_path( - oidcclient.application_base_url.clone(), - &parts.uri, - &oidcclient.oidc_request_parameters, - )?; + // generate a login url and redirect the user to it - oidcclient.client = oidcclient - .client - .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, csrf_token, nonce) = { + let mut auth = oidcclient.client.authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); - if let (Some(mut login_session), Some(query)) = (login_session, query) { - // the request has the request headers of the oidc redirect - // parse the headers and exchange the code for a valid token - - if login_session.csrf_token.secret() != &query.state { - return Err(MiddlewareError::CsrfTokenInvalid); + for scope in oidcclient.scopes.iter() { + auth = auth.add_scope(Scope::new(scope.to_string())); } - let token_response = oidcclient - .client - .exchange_code(AuthorizationCode::new(query.code.to_string()))? - // Set the PKCE code verifier. - .set_pkce_verifier(PkceCodeVerifier::new( - login_session.pkce_verifier.secret().to_string(), - )) - .request_async(&oidcclient.http_client) - .await?; - - // Extract the ID token claims after verifying its authenticity and nonce. - let id_token = token_response - .id_token() - .ok_or(MiddlewareError::IdTokenMissing)?; - let id_token_verifier = oidcclient.client.id_token_verifier(); - let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; - - validate_access_token_hash( - id_token, - id_token_verifier, - token_response.access_token(), - claims, - )?; - - login_session.authenticated = Some(AuthenticatedSession { - id_token: id_token.clone(), - access_token: token_response.access_token().clone(), - }); - let refresh_token = token_response.refresh_token().cloned(); - if let Some(refresh_token) = refresh_token { - login_session.refresh_token = Some(refresh_token); + if let Some(acr) = oidcclient.auth_context_class { + auth = auth + .add_auth_context_value(AuthenticationContextClass::new(acr.into())); } - session.insert(SESSION_KEY, login_session).await?; + auth.set_pkce_challenge(pkce_challenge).url() + }; - Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) - } else { - // generate a login url and redirect the user to it + let oidc_session = OidcSession:: { + nonce, + csrf_token, + pkce_verifier, + authenticated: None, + refresh_token: None, + redirect_url: parts.uri.to_string().into(), + }; - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - let (auth_url, csrf_token, nonce) = { - let mut auth = oidcclient.client.authorize_url( - CoreAuthenticationFlow::AuthorizationCode, - CsrfToken::new_random, - Nonce::new_random, - ); + session.insert(SESSION_KEY, oidc_session).await?; - for scope in oidcclient.scopes.iter() { - auth = auth.add_scope(Scope::new(scope.to_string())); - } - - if let Some(acr) = oidcclient.auth_context_class { - auth = auth.add_auth_context_value(AuthenticationContextClass::new( - acr.into(), - )); - } - - auth.set_pkce_challenge(pkce_challenge).url() - }; - - let oidc_session = OidcSession:: { - nonce, - csrf_token, - pkce_verifier, - authenticated: None, - refresh_token: None, - }; - - session.insert(SESSION_KEY, oidc_session).await?; - - Ok(Redirect::temporary(auth_url.as_str()).into_response()) - } + Ok(Redirect::to(auth_url.as_str()).into_response()) }) } } @@ -291,7 +223,7 @@ where fn call(&mut self, request: Request) -> Self::Future { let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); - let mut oidcclient = self.client.clone(); + let oidcclient = self.client.clone(); Box::pin(async move { let (mut parts, body) = request.into_parts(); @@ -305,16 +237,6 @@ where .await .map_err(MiddlewareError::from)?; - let handler_uri = strip_oidc_from_path( - oidcclient.application_base_url.clone(), - &parts.uri, - &oidcclient.oidc_request_parameters, - )?; - - oidcclient.client = oidcclient - .client - .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); - if let Some(login_session) = &mut login_session { let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { session @@ -329,6 +251,7 @@ where // stored id token is valid and can be used insert_extensions(&mut parts, claims.clone(), &oidcclient, session); } else if let Some(refresh_token) = login_session.refresh_token.as_ref() { + // session is expired but can be refreshed using the refresh_token if let Some((claims, authenticated_session, refresh_token)) = try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? { @@ -370,41 +293,6 @@ where } } -/// Helper function to remove the OpenID Connect authentication response query attributes from a -/// [`Uri`]. -pub fn strip_oidc_from_path( - base_url: Uri, - uri: &Uri, - filter: &[Box], -) -> Result { - let mut base_url = base_url.into_parts(); - - base_url.path_and_query = uri - .path_and_query() - .map(|path_and_query| { - let query = path_and_query - .query() - .map(|uri| { - uri.split('&') - .filter(|x| filter.iter().all(|y| !x.starts_with(y.as_ref()))) - .fold(String::default(), |mut acc, x| { - if !acc.is_empty() { - acc += "&"; - } else { - acc += "?"; - } - acc += x; - acc - }) - }) - .unwrap_or_default(); - PathAndQuery::from_maybe_shared(format!("{}{}", path_and_query.path(), query)) - }) - .transpose()?; - - Ok(Uri::from_parts(base_url)?) -} - /// insert all extensions that are used by the extractors fn insert_extensions( parts: &mut Parts,