diff --git a/examples/basic/src/lib.rs b/examples/basic/src/lib.rs index 389d742..be7daca 100644 --- a/examples/basic/src/lib.rs +++ b/examples/basic/src/lib.rs @@ -2,8 +2,8 @@ use axum::{ error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, }; use axum_oidc::{ - error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, - OidcRpInitiatedLogout, + error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, + OidcLoginLayer, OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -26,25 +26,30 @@ pub async fn run( let oidc_login_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { + dbg!(&e); e.into_response() })) .layer(OidcLoginLayer::::new()); + let mut oidc_client = OidcClient::::builder() + .with_default_http_client() + .with_application_base_url(Uri::from_maybe_shared(app_url).expect("valid APP_URL")) + .with_client_id(client_id); + if let Some(client_secret) = client_secret { + oidc_client = oidc_client.with_client_secret(client_secret); + } + let oidc_client = oidc_client + .discover(Uri::from_maybe_shared(issuer).expect("valid issuer URI")) + .await + .unwrap() + .build(); + let oidc_auth_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { + dbg!(&e); e.into_response() })) - .layer( - OidcAuthLayer::::discover_client( - Uri::from_maybe_shared(app_url).expect("valid APP_URL"), - issuer, - client_id, - client_secret, - vec![], - ) - .await - .unwrap(), - ); + .layer(OidcAuthLayer::new(oidc_client)); let app = Router::new() .route("/foo", get(authenticated)) @@ -79,5 +84,5 @@ async fn maybe_authenticated( } async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { - logout.with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) + logout.with_post_logout_redirect(Uri::from_static("https://example.com")) } diff --git a/src/builder.rs b/src/builder.rs new file mode 100644 index 0000000..ba3b351 --- /dev/null +++ b/src/builder.rs @@ -0,0 +1,255 @@ +use std::marker::PhantomData; + +use http::Uri; +use openidconnect::{ClientId, ClientSecret, IssuerUrl}; + +use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata}; + +pub struct Unconfigured; +pub struct ApplicationBaseUrl(Uri); +pub struct OpenidconnectClient(crate::Client); +pub struct HttpClient(reqwest::Client); + +pub struct ClientCredentials { + id: Box, + secret: Option>, +} + +pub struct Builder { + application_base_url: ApplicationBaseUrl, + credentials: Credentials, + client: Client, + http_client: HttpClient, + end_session_endpoint: Option, + scopes: Vec>, + oidc_request_parameters: Vec>, + auth_context_class: Option>, + _ac: PhantomData, +} + +impl Default for Builder { + fn default() -> Self { + Self::new() + } +} +impl Builder { + /// create a new builder with default values + pub fn new() -> Self { + let oidc_request_parameters = ["code", "state", "session_state", "iss"] + .into_iter() + .map(Box::::from) + .collect(); + + Self { + application_base_url: (), + credentials: (), + client: (), + http_client: (), + end_session_endpoint: None, + scopes: vec![Box::from("openid")], + oidc_request_parameters, + auth_context_class: None, + _ac: PhantomData, + } + } +} + +impl OidcClient { + /// create a new builder with default values + pub fn builder() -> Builder { + Builder::::new() + } +} + +impl Builder { + /// add a scope to existing (default) scopes + pub fn add_scope(mut self, scope: impl Into>) -> Self { + self.scopes.push(scope.into()); + self + } + /// replace scopes (including default) + pub fn with_scopes(mut self, scopes: impl Iterator>>) -> Self { + self.scopes = scopes.map(|x| x.into()).collect::>(); + self + } + + /// add a query parameter that will be filtered from requests to existing (default) filtered + /// query parameters + pub fn add_oidc_request_parameter( + mut self, + oidc_request_parameter: impl Into>, + ) -> Self { + self.oidc_request_parameters + .push(oidc_request_parameter.into()); + self + } + + /// replace query parameters that will be filtered from requests (including default) + pub fn with_oidc_request_parameters( + mut self, + oidc_request_parameters: impl Iterator>>, + ) -> Self { + self.oidc_request_parameters = oidc_request_parameters + .map(|x| x.into()) + .collect::>(); + self + } + + /// authenticate with Authentication Context Class Reference + pub fn with_auth_context_class(mut self, acr: impl Into>) -> Self { + self.auth_context_class = Some(acr.into()); + self + } +} + +impl Builder { + /// set application base url (e.g. https://example.com) + pub fn with_application_base_url( + self, + url: impl Into, + ) -> Builder { + Builder { + application_base_url: ApplicationBaseUrl(url.into()), + credentials: self.credentials, + client: self.client, + http_client: self.http_client, + end_session_endpoint: self.end_session_endpoint, + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + auth_context_class: self.auth_context_class, + _ac: PhantomData, + } + } +} + +impl Builder { + /// set client id for authentication with issuer + pub fn with_client_id( + self, + id: impl Into>, + ) -> Builder { + Builder::<_, _, _, _, _> { + application_base_url: self.application_base_url, + credentials: ClientCredentials { + id: id.into(), + secret: None, + }, + client: self.client, + http_client: self.http_client, + end_session_endpoint: self.end_session_endpoint, + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + auth_context_class: self.auth_context_class, + _ac: PhantomData, + } + } +} + +impl Builder { + /// set client secret for authentication with issuer + pub fn with_client_secret(mut self, secret: impl Into>) -> Self { + self.credentials.secret = Some(secret.into()); + self + } +} + +impl Builder { + /// use custom http client + pub fn with_http_client( + self, + client: reqwest::Client, + ) -> Builder { + Builder { + application_base_url: self.application_base_url, + credentials: self.credentials, + client: self.client, + http_client: HttpClient(client), + end_session_endpoint: self.end_session_endpoint, + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + auth_context_class: self.auth_context_class, + _ac: self._ac, + } + } + /// use default reqwest http client + pub fn with_default_http_client(self) -> Builder { + Builder { + application_base_url: self.application_base_url, + credentials: self.credentials, + client: self.client, + http_client: HttpClient(reqwest::Client::default()), + end_session_endpoint: self.end_session_endpoint, + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + auth_context_class: self.auth_context_class, + _ac: self._ac, + } + } +} + +impl Builder { + /// provide issuer details manually + pub fn manual( + self, + provider_metadata: ProviderMetadata, + ) -> Result, HttpClient>, Error> + { + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; + let client = Client::from_provider_metadata( + provider_metadata, + ClientId::new(self.credentials.id.to_string()), + self.credentials + .secret + .as_ref() + .map(|x| ClientSecret::new(x.to_string())), + ); + + Ok(Builder { + application_base_url: self.application_base_url, + credentials: self.credentials, + client: OpenidconnectClient(client), + http_client: self.http_client, + end_session_endpoint, + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + auth_context_class: self.auth_context_class, + _ac: self._ac, + }) + } + /// discover issuer details + pub async fn discover( + self, + issuer: impl Into, + ) -> Result, HttpClient>, Error> + { + let issuer_url = IssuerUrl::new(issuer.into().to_string())?; + let http_client = self.http_client.0.clone(); + let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client); + + Self::manual(self, provider_metadata.await?) + } +} + +impl + Builder, HttpClient> +{ + /// create oidc client + pub fn build(self) -> OidcClient { + OidcClient { + scopes: self.scopes, + oidc_request_parameters: self.oidc_request_parameters, + client_id: self.credentials.id, + client: self.client.0, + http_client: self.http_client.0, + application_base_url: self.application_base_url.0, + end_session_endpoint: self.end_session_endpoint, + auth_context_class: self.auth_context_class, + } + } +} diff --git a/src/error.rs b/src/error.rs index 48f0903..0bd1cdc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -111,7 +111,6 @@ impl IntoResponse for ExtractorError { impl IntoResponse for Error { fn into_response(self) -> axum_core::response::Response { - dbg!(&self); match self { _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), } @@ -120,7 +119,6 @@ impl IntoResponse for Error { impl IntoResponse for MiddlewareError { fn into_response(self) -> axum_core::response::Response { - dbg!(&self); match self { _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), } diff --git a/src/extractor.rs b/src/extractor.rs index 01635a8..8735ee7 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -112,8 +112,8 @@ impl AsRef for OidcAccessToken { #[derive(Clone)] pub struct OidcRpInitiatedLogout { pub(crate) end_session_endpoint: Uri, - pub(crate) id_token_hint: String, - pub(crate) client_id: String, + pub(crate) id_token_hint: Box, + pub(crate) client_id: Box, pub(crate) post_logout_redirect_uri: Option, pub(crate) state: Option, } diff --git a/src/lib.rs b/src/lib.rs index bf3c15e..3319875 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ #![deny(warnings)] #![doc = include_str!("../README.md")] -use crate::error::Error; use http::Uri; use openidconnect::{ core::{ @@ -13,12 +12,13 @@ use openidconnect::{ CoreResponseMode, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, CoreTokenType, }, - AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet, - EndpointNotSet, EndpointSet, IdTokenFields, IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, - StandardErrorResponse, StandardTokenResponse, + AccessToken, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet, EndpointNotSet, EndpointSet, + IdTokenFields, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse, + StandardTokenResponse, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +pub mod builder; pub mod error; mod extractor; mod middleware; @@ -99,118 +99,14 @@ pub type BoxError = Box; /// OpenID Connect Client #[derive(Clone)] pub struct OidcClient { - scopes: Vec, - client_id: String, + scopes: Vec>, + oidc_request_parameters: Vec>, + client_id: Box, client: Client, http_client: reqwest::Client, application_base_url: Uri, end_session_endpoint: Option, -} - -impl OidcClient { - /// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. - pub fn from_provider_metadata( - provider_metadata: ProviderMetadata, - application_base_url: Uri, - client_id: String, - client_secret: Option, - scopes: Vec, - ) -> Result { - let end_session_endpoint = provider_metadata - .additional_metadata() - .end_session_endpoint - .clone() - .map(Uri::from_maybe_shared) - .transpose() - .map_err(Error::InvalidEndSessionEndpoint)?; - let client = Client::from_provider_metadata( - provider_metadata, - ClientId::new(client_id.clone()), - client_secret.map(ClientSecret::new), - ); - Ok(Self { - scopes, - client, - client_id, - application_base_url, - end_session_endpoint, - http_client: reqwest::Client::default(), - }) - } - /// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. - pub fn from_provider_metadata_and_client( - provider_metadata: ProviderMetadata, - application_base_url: Uri, - client_id: String, - client_secret: Option, - scopes: Vec, - http_client: reqwest::Client, - ) -> Result { - let end_session_endpoint = provider_metadata - .additional_metadata() - .end_session_endpoint - .clone() - .map(Uri::from_maybe_shared) - .transpose() - .map_err(Error::InvalidEndSessionEndpoint)?; - let client = Client::from_provider_metadata( - provider_metadata, - ClientId::new(client_id.clone()), - client_secret.map(ClientSecret::new), - ); - Ok(Self { - scopes, - client, - client_id, - application_base_url, - end_session_endpoint, - http_client, - }) - } - - /// create a new [`OidcClient`] by fetching the required information from the - /// `/.well-known/openid-configuration` endpoint of the issuer. - pub async fn discover_new( - application_base_url: Uri, - issuer: String, - client_id: String, - client_secret: Option, - scopes: Vec, - ) -> Result { - let client = reqwest::Client::default(); - Self::discover_new_with_client( - application_base_url, - issuer, - client_id, - client_secret, - scopes, - client, - ) - .await - } - - /// create a new [`OidcClient`] by fetching the required information from the - /// `/.well-known/openid-configuration` endpoint of the issuer using the provided - /// `reqwest::Client`. - pub async fn discover_new_with_client( - application_base_url: Uri, - issuer: String, - client_id: String, - client_secret: Option, - scopes: Vec, - client: reqwest::Client, - ) -> Result { - let provider_metadata = - ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &client).await?; - Self::from_provider_metadata_and_client( - provider_metadata, - application_base_url, - client_id, - client_secret, - scopes, - client, - ) - } + auth_context_class: Option>, } /// an empty struct to be used as the default type for the additional claims generic diff --git a/src/middleware.rs b/src/middleware.rs index 5538108..f68e923 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -16,14 +16,15 @@ use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, - AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, IdTokenVerifier, - Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, + AccessToken, AccessTokenHash, AuthenticationContextClass, AuthorizationCode, CsrfToken, + IdTokenClaims, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, + PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; use crate::{ - error::{Error, MiddlewareError}, + error::MiddlewareError, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, OidcQuery, OidcSession, SESSION_KEY, @@ -126,8 +127,11 @@ where .await .map_err(MiddlewareError::from)?; - let handler_uri = - strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?; + let handler_uri = strip_oidc_from_path( + oidcclient.application_base_url.clone(), + &parts.uri, + &oidcclient.oidc_request_parameters, + )?; oidcclient.client = oidcclient .client @@ -192,6 +196,12 @@ where auth = auth.add_scope(Scope::new(scope.to_string())); } + if let Some(acr) = oidcclient.auth_context_class { + auth = auth.add_auth_context_value(AuthenticationContextClass::new( + acr.into(), + )); + } + auth.set_pkce_challenge(pkce_challenge).url() }; @@ -225,24 +235,10 @@ impl OidcAuthLayer { pub fn new(client: OidcClient) -> Self { Self { client } } - - pub async fn discover_client( - application_base_url: Uri, - issuer: String, - client_id: String, - client_secret: Option, - scopes: Vec, - ) -> Result { - Ok(Self { - client: OidcClient::::discover_new( - application_base_url, - issuer, - client_id, - client_secret, - scopes, - ) - .await?, - }) +} +impl From> for OidcAuthLayer { + fn from(value: OidcClient) -> Self { + Self::new(value) } } @@ -309,8 +305,11 @@ where .await .map_err(MiddlewareError::from)?; - let handler_uri = - strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?; + let handler_uri = strip_oidc_from_path( + oidcclient.application_base_url.clone(), + &parts.uri, + &oidcclient.oidc_request_parameters, + )?; oidcclient.client = oidcclient .client @@ -373,7 +372,11 @@ where /// Helper function to remove the OpenID Connect authentication response query attributes from a /// [`Uri`]. -pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result { +pub fn strip_oidc_from_path( + base_url: Uri, + uri: &Uri, + filter: &[Box], +) -> Result { let mut base_url = base_url.into_parts(); base_url.path_and_query = uri @@ -381,20 +384,20 @@ pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result( .as_ref() .map(|end_session_endpoint| OidcRpInitiatedLogout { end_session_endpoint: end_session_endpoint.clone(), - id_token_hint: authenticated_session.id_token.to_string(), + id_token_hint: authenticated_session.id_token.to_string().into(), client_id: client.client_id.clone(), post_logout_redirect_uri: None, state: None,