diff --git a/Cargo.toml b/Cargo.toml index 5cd18f0..1f3b7b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ] thiserror = "2.0" axum-core = "0.5" axum = { version = "0.8", default-features = false, features = [ "query" ] } +bon = "3.3" tower-service = "0.3" tower-layer = "0.3" tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } diff --git a/src/lib.rs b/src/lib.rs index a6825b3..3593114 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ use crate::error::Error; use http::Uri; +use oidc_client_builder::{IsUnset, SetClient, SetClientId, SetEndSessionEndpoint, State}; use openidconnect::{ core::{ CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, @@ -93,16 +94,109 @@ pub type ProviderMetadata = openidconnect::ProviderMetadata< pub type BoxError = Box; /// OpenID Connect Client -#[derive(Clone)] +#[derive(Clone, bon::Builder)] pub struct OidcClient { - scopes: Vec, - client_id: String, - client: Client, + #[builder(field)] http_client: reqwest::Client, + + scopes: Vec, + client: Client, + client_id: String, application_base_url: Uri, end_session_endpoint: Option, } +type SetFinal = OidcClientBuilder>>>; + +impl OidcClientBuilder { + pub fn http_client(mut self, http_client: reqwest::Client) -> Self { + self.http_client = http_client; + self + } + + /// set `end_session_endpoint` and initialize a `client` from an existing [`ProviderMetadata`]. + pub fn with_provider_metadata( + self, + provider_metadata: ProviderMetadata, + client_id: String, + client_secret: Option, + ) -> Result, Error> + where + S::EndSessionEndpoint: IsUnset, + S::Client: IsUnset, + S::ClientId: IsUnset, + { + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; + let client = Client::from_provider_metadata( + provider_metadata, + ClientId::new(client_id.clone()), + client_secret.map(ClientSecret::new), + ); + + Ok(self + .maybe_end_session_endpoint(end_session_endpoint) + .client(client) + .client_id(client_id)) + } + + /// create a new [`OidcClient`] by fetching the required information from the + /// `/.well-known/openid-configuration` endpoint of the issuer. + pub async fn discover_new_with_client( + self, + issuer: String, + client_id: String, + client_secret: Option, + ) -> Result, Error> + where + S::EndSessionEndpoint: IsUnset, + S::Client: IsUnset, + S::ClientId: IsUnset, + { + let http_client = &self.http_client.clone(); + + // modified version of `openidconnect::reqwest::async_client::async_http_client`. + let async_http_client = |request: HttpRequest| async move { + let mut request_builder = http_client + .request(request.method, request.url.as_str()) + .body(request.body); + for (name, value) in &request.headers { + request_builder = request_builder.header(name.as_str(), value.as_bytes()); + } + let request = request_builder + .build() + .map_err(openidconnect::reqwest::Error::Reqwest)?; + + let response = http_client + .execute(request) + .await + .map_err(openidconnect::reqwest::Error::Reqwest)?; + + let status_code = response.status(); + let headers = response.headers().to_owned(); + let chunks = response + .bytes() + .await + .map_err(openidconnect::reqwest::Error::Reqwest)?; + Ok(HttpResponse { + status_code, + headers, + body: chunks.to_vec(), + }) + }; + + let provider_metadata = + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; + + self.with_provider_metadata(provider_metadata, client_id, client_secret) + } +} + impl OidcClient { /// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. pub fn from_provider_metadata( @@ -112,26 +206,11 @@ impl OidcClient { client_secret: Option, scopes: Vec, ) -> Result { - let end_session_endpoint = provider_metadata - .additional_metadata() - .end_session_endpoint - .clone() - .map(Uri::from_maybe_shared) - .transpose() - .map_err(Error::InvalidEndSessionEndpoint)?; - let client = Client::from_provider_metadata( - provider_metadata, - ClientId::new(client_id.clone()), - client_secret.map(ClientSecret::new), - ); - Ok(Self { - scopes, - client, - client_id, - application_base_url, - end_session_endpoint, - http_client: reqwest::Client::default(), - }) + Ok(Self::builder() + .scopes(scopes) + .application_base_url(application_base_url) + .with_provider_metadata(provider_metadata, client_id, client_secret)? + .build()) } /// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. pub fn from_provider_metadata_and_client( @@ -142,26 +221,12 @@ impl OidcClient { scopes: Vec, http_client: reqwest::Client, ) -> Result { - let end_session_endpoint = provider_metadata - .additional_metadata() - .end_session_endpoint - .clone() - .map(Uri::from_maybe_shared) - .transpose() - .map_err(Error::InvalidEndSessionEndpoint)?; - let client = Client::from_provider_metadata( - provider_metadata, - ClientId::new(client_id.clone()), - client_secret.map(ClientSecret::new), - ); - Ok(Self { - scopes, - client, - client_id, - application_base_url, - end_session_endpoint, - http_client, - }) + Ok(Self::builder() + .http_client(http_client) + .scopes(scopes) + .application_base_url(application_base_url) + .with_provider_metadata(provider_metadata, client_id, client_secret)? + .build()) } /// create a new [`OidcClient`] by fetching the required information from the @@ -197,46 +262,13 @@ impl OidcClient { //TODO remove borrow with next breaking version client: &reqwest::Client, ) -> Result { - // modified version of `openidconnect::reqwest::async_client::async_http_client`. - let async_http_client = |request: HttpRequest| async move { - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); - } - let request = request_builder - .build() - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let response = client - .execute(request) - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let status_code = response.status(); - let headers = response.headers().to_owned(); - let chunks = response - .bytes() - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - Ok(HttpResponse { - status_code, - headers, - body: chunks.to_vec(), - }) - }; - - let provider_metadata = - ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; - Self::from_provider_metadata_and_client( - provider_metadata, - application_base_url, - client_id, - client_secret, - scopes, - client.clone(), - ) + Ok(Self::builder() + .http_client(client.clone()) + .scopes(scopes) + .application_base_url(application_base_url) + .discover_new_with_client(issuer, client_id, client_secret) + .await? + .build()) } }