feat(builder): Add builder pattern for OidcClient

Signed-off-by: MATILLAT Quentin <qmatillat@gmail.com>
This commit is contained in:
MATILLAT Quentin 2025-02-17 11:57:10 +01:00
parent f0d9126652
commit 352a4dc187
No known key found for this signature in database
GPG key ID: B9BAF56E288158D2
2 changed files with 117 additions and 84 deletions

View file

@ -15,6 +15,7 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ]
thiserror = "2.0" thiserror = "2.0"
axum-core = "0.5" axum-core = "0.5"
axum = { version = "0.8", default-features = false, features = [ "query" ] } axum = { version = "0.8", default-features = false, features = [ "query" ] }
bon = "3.3"
tower-service = "0.3" tower-service = "0.3"
tower-layer = "0.3" tower-layer = "0.3"
tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] }

View file

@ -5,6 +5,7 @@
use crate::error::Error; use crate::error::Error;
use http::Uri; use http::Uri;
use oidc_client_builder::{IsUnset, SetClient, SetClientId, SetEndSessionEndpoint, State};
use openidconnect::{ use openidconnect::{
core::{ core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod,
@ -93,16 +94,109 @@ pub type ProviderMetadata = openidconnect::ProviderMetadata<
pub type BoxError = Box<dyn std::error::Error + Send + Sync>; pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// OpenID Connect Client /// OpenID Connect Client
#[derive(Clone)] #[derive(Clone, bon::Builder)]
pub struct OidcClient<AC: AdditionalClaims> { pub struct OidcClient<AC: AdditionalClaims> {
scopes: Vec<String>, #[builder(field)]
client_id: String,
client: Client<AC>,
http_client: reqwest::Client, http_client: reqwest::Client,
scopes: Vec<String>,
client: Client<AC>,
client_id: String,
application_base_url: Uri, application_base_url: Uri,
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
} }
type SetFinal<AC, S> = OidcClientBuilder<AC, SetClientId<SetClient<SetEndSessionEndpoint<S>>>>;
impl<AC: AdditionalClaims, S: State> OidcClientBuilder<AC, S> {
pub fn http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = http_client;
self
}
/// set `end_session_endpoint` and initialize a `client` from an existing [`ProviderMetadata`].
pub fn with_provider_metadata(
self,
provider_metadata: ProviderMetadata,
client_id: String,
client_secret: Option<String>,
) -> Result<SetFinal<AC, S>, Error>
where
S::EndSessionEndpoint: IsUnset,
S::Client: IsUnset,
S::ClientId: IsUnset,
{
let end_session_endpoint = provider_metadata
.additional_metadata()
.end_session_endpoint
.clone()
.map(Uri::from_maybe_shared)
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(self
.maybe_end_session_endpoint(end_session_endpoint)
.client(client)
.client_id(client_id))
}
/// create a new [`OidcClient`] by fetching the required information from the
/// `/.well-known/openid-configuration` endpoint of the issuer.
pub async fn discover_new_with_client(
self,
issuer: String,
client_id: String,
client_secret: Option<String>,
) -> Result<SetFinal<AC, S>, Error>
where
S::EndSessionEndpoint: IsUnset,
S::Client: IsUnset,
S::ClientId: IsUnset,
{
let http_client = &self.http_client.clone();
// modified version of `openidconnect::reqwest::async_client::async_http_client`.
let async_http_client = |request: HttpRequest| async move {
let mut request_builder = http_client
.request(request.method, request.url.as_str())
.body(request.body);
for (name, value) in &request.headers {
request_builder = request_builder.header(name.as_str(), value.as_bytes());
}
let request = request_builder
.build()
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let response = http_client
.execute(request)
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let status_code = response.status();
let headers = response.headers().to_owned();
let chunks = response
.bytes()
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
Ok(HttpResponse {
status_code,
headers,
body: chunks.to_vec(),
})
};
let provider_metadata =
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
self.with_provider_metadata(provider_metadata, client_id, client_secret)
}
}
impl<AC: AdditionalClaims> OidcClient<AC> { impl<AC: AdditionalClaims> OidcClient<AC> {
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. /// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
pub fn from_provider_metadata( pub fn from_provider_metadata(
@ -112,26 +206,11 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
client_secret: Option<String>, client_secret: Option<String>,
scopes: Vec<String>, scopes: Vec<String>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let end_session_endpoint = provider_metadata Ok(Self::builder()
.additional_metadata() .scopes(scopes)
.end_session_endpoint .application_base_url(application_base_url)
.clone() .with_provider_metadata(provider_metadata, client_id, client_secret)?
.map(Uri::from_maybe_shared) .build())
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(Self {
scopes,
client,
client_id,
application_base_url,
end_session_endpoint,
http_client: reqwest::Client::default(),
})
} }
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. /// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
pub fn from_provider_metadata_and_client( pub fn from_provider_metadata_and_client(
@ -142,26 +221,12 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
scopes: Vec<String>, scopes: Vec<String>,
http_client: reqwest::Client, http_client: reqwest::Client,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let end_session_endpoint = provider_metadata Ok(Self::builder()
.additional_metadata() .http_client(http_client)
.end_session_endpoint .scopes(scopes)
.clone() .application_base_url(application_base_url)
.map(Uri::from_maybe_shared) .with_provider_metadata(provider_metadata, client_id, client_secret)?
.transpose() .build())
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(Self {
scopes,
client,
client_id,
application_base_url,
end_session_endpoint,
http_client,
})
} }
/// create a new [`OidcClient`] by fetching the required information from the /// create a new [`OidcClient`] by fetching the required information from the
@ -197,46 +262,13 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
//TODO remove borrow with next breaking version //TODO remove borrow with next breaking version
client: &reqwest::Client, client: &reqwest::Client,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// modified version of `openidconnect::reqwest::async_client::async_http_client`. Ok(Self::builder()
let async_http_client = |request: HttpRequest| async move { .http_client(client.clone())
let mut request_builder = client .scopes(scopes)
.request(request.method, request.url.as_str()) .application_base_url(application_base_url)
.body(request.body); .discover_new_with_client(issuer, client_id, client_secret)
for (name, value) in &request.headers { .await?
request_builder = request_builder.header(name.as_str(), value.as_bytes()); .build())
}
let request = request_builder
.build()
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let response = client
.execute(request)
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let status_code = response.status();
let headers = response.headers().to_owned();
let chunks = response
.bytes()
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
Ok(HttpResponse {
status_code,
headers,
body: chunks.to_vec(),
})
};
let provider_metadata =
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
Self::from_provider_metadata_and_client(
provider_metadata,
application_base_url,
client_id,
client_secret,
scopes,
client.clone(),
)
} }
} }