feat: add typestate OidcClient builder

The previous generator functions for `OidcClient` have been replaced by
a Builder.
With this change the suggested changes by #14 and #21 have been
implemented.
This commit is contained in:
Paul Zinselmeyer 2025-02-18 21:26:56 +01:00
parent 6d7fc3c7f1
commit 58369449cf
Signed by: pfzetto
GPG key ID: B471A1AF06C895FD
6 changed files with 324 additions and 167 deletions

View file

@ -2,8 +2,8 @@ use axum::{
error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router,
};
use axum_oidc::{
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer,
OidcRpInitiatedLogout,
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient,
OidcLoginLayer, OidcRpInitiatedLogout,
};
use tokio::net::TcpListener;
use tower::ServiceBuilder;
@ -26,25 +26,30 @@ pub async fn run(
let oidc_login_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
dbg!(&e);
e.into_response()
}))
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
.with_default_http_client()
.with_application_base_url(Uri::from_maybe_shared(app_url).expect("valid APP_URL"))
.with_client_id(client_id);
if let Some(client_secret) = client_secret {
oidc_client = oidc_client.with_client_secret(client_secret);
}
let oidc_client = oidc_client
.discover(Uri::from_maybe_shared(issuer).expect("valid issuer URI"))
.await
.unwrap()
.build();
let oidc_auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
dbg!(&e);
e.into_response()
}))
.layer(
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
Uri::from_maybe_shared(app_url).expect("valid APP_URL"),
issuer,
client_id,
client_secret,
vec![],
)
.await
.unwrap(),
);
.layer(OidcAuthLayer::new(oidc_client));
let app = Router::new()
.route("/foo", get(authenticated))
@ -79,5 +84,5 @@ async fn maybe_authenticated(
}
async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse {
logout.with_post_logout_redirect(Uri::from_static("https://pfzetto.de"))
logout.with_post_logout_redirect(Uri::from_static("https://example.com"))
}

255
src/builder.rs Normal file
View file

@ -0,0 +1,255 @@
use std::marker::PhantomData;
use http::Uri;
use openidconnect::{ClientId, ClientSecret, IssuerUrl};
use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata};
pub struct Unconfigured;
pub struct ApplicationBaseUrl(Uri);
pub struct OpenidconnectClient<AC: AdditionalClaims>(crate::Client<AC>);
pub struct HttpClient(reqwest::Client);
pub struct ClientCredentials {
id: Box<str>,
secret: Option<Box<str>>,
}
pub struct Builder<AC: AdditionalClaims, ApplicationBaseUrl, Credentials, Client, HttpClient> {
application_base_url: ApplicationBaseUrl,
credentials: Credentials,
client: Client,
http_client: HttpClient,
end_session_endpoint: Option<Uri>,
scopes: Vec<Box<str>>,
oidc_request_parameters: Vec<Box<str>>,
auth_context_class: Option<Box<str>>,
_ac: PhantomData<AC>,
}
impl<AC: AdditionalClaims> Default for Builder<AC, (), (), (), ()> {
fn default() -> Self {
Self::new()
}
}
impl<AC: AdditionalClaims> Builder<AC, (), (), (), ()> {
/// create a new builder with default values
pub fn new() -> Self {
let oidc_request_parameters = ["code", "state", "session_state", "iss"]
.into_iter()
.map(Box::<str>::from)
.collect();
Self {
application_base_url: (),
credentials: (),
client: (),
http_client: (),
end_session_endpoint: None,
scopes: vec![Box::from("openid")],
oidc_request_parameters,
auth_context_class: None,
_ac: PhantomData,
}
}
}
impl<AC: AdditionalClaims> OidcClient<AC> {
/// create a new builder with default values
pub fn builder() -> Builder<AC, (), (), (), ()> {
Builder::<AC, (), (), (), ()>::new()
}
}
impl<AC: AdditionalClaims, APPBASE, CREDS, CLIENT, HTTP> Builder<AC, APPBASE, CREDS, CLIENT, HTTP> {
/// add a scope to existing (default) scopes
pub fn add_scope(mut self, scope: impl Into<Box<str>>) -> Self {
self.scopes.push(scope.into());
self
}
/// replace scopes (including default)
pub fn with_scopes(mut self, scopes: impl Iterator<Item = impl Into<Box<str>>>) -> Self {
self.scopes = scopes.map(|x| x.into()).collect::<Vec<_>>();
self
}
/// add a query parameter that will be filtered from requests to existing (default) filtered
/// query parameters
pub fn add_oidc_request_parameter(
mut self,
oidc_request_parameter: impl Into<Box<str>>,
) -> Self {
self.oidc_request_parameters
.push(oidc_request_parameter.into());
self
}
/// replace query parameters that will be filtered from requests (including default)
pub fn with_oidc_request_parameters(
mut self,
oidc_request_parameters: impl Iterator<Item = impl Into<Box<str>>>,
) -> Self {
self.oidc_request_parameters = oidc_request_parameters
.map(|x| x.into())
.collect::<Vec<_>>();
self
}
/// authenticate with Authentication Context Class Reference
pub fn with_auth_context_class(mut self, acr: impl Into<Box<str>>) -> Self {
self.auth_context_class = Some(acr.into());
self
}
}
impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP> Builder<AC, (), CREDS, CLIENT, HTTP> {
/// set application base url (e.g. https://example.com)
pub fn with_application_base_url(
self,
url: impl Into<Uri>,
) -> Builder<AC, ApplicationBaseUrl, CREDS, CLIENT, HTTP> {
Builder {
application_base_url: ApplicationBaseUrl(url.into()),
credentials: self.credentials,
client: self.client,
http_client: self.http_client,
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: PhantomData,
}
}
}
impl<AC: AdditionalClaims, ABU, CLIENT, HTTP> Builder<AC, ABU, (), CLIENT, HTTP> {
/// set client id for authentication with issuer
pub fn with_client_id(
self,
id: impl Into<Box<str>>,
) -> Builder<AC, ABU, ClientCredentials, CLIENT, HTTP> {
Builder::<_, _, _, _, _> {
application_base_url: self.application_base_url,
credentials: ClientCredentials {
id: id.into(),
secret: None,
},
client: self.client,
http_client: self.http_client,
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: PhantomData,
}
}
}
impl<AC: AdditionalClaims, ABU, CLIENT, HTTP> Builder<AC, ABU, ClientCredentials, CLIENT, HTTP> {
/// set client secret for authentication with issuer
pub fn with_client_secret(mut self, secret: impl Into<Box<str>>) -> Self {
self.credentials.secret = Some(secret.into());
self
}
}
impl<AC: AdditionalClaims, ABU, CREDS, CLIENT> Builder<AC, ABU, CREDS, CLIENT, ()> {
/// use custom http client
pub fn with_http_client(
self,
client: reqwest::Client,
) -> Builder<AC, ABU, CREDS, CLIENT, HttpClient> {
Builder {
application_base_url: self.application_base_url,
credentials: self.credentials,
client: self.client,
http_client: HttpClient(client),
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: self._ac,
}
}
/// use default reqwest http client
pub fn with_default_http_client(self) -> Builder<AC, ABU, CREDS, CLIENT, HttpClient> {
Builder {
application_base_url: self.application_base_url,
credentials: self.credentials,
client: self.client,
http_client: HttpClient(reqwest::Client::default()),
end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: self._ac,
}
}
}
impl<AC: AdditionalClaims, ABU> Builder<AC, ABU, ClientCredentials, (), HttpClient> {
/// provide issuer details manually
pub fn manual(
self,
provider_metadata: ProviderMetadata,
) -> Result<Builder<AC, ABU, ClientCredentials, OpenidconnectClient<AC>, HttpClient>, Error>
{
let end_session_endpoint = provider_metadata
.additional_metadata()
.end_session_endpoint
.clone()
.map(Uri::from_maybe_shared)
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(self.credentials.id.to_string()),
self.credentials
.secret
.as_ref()
.map(|x| ClientSecret::new(x.to_string())),
);
Ok(Builder {
application_base_url: self.application_base_url,
credentials: self.credentials,
client: OpenidconnectClient(client),
http_client: self.http_client,
end_session_endpoint,
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
auth_context_class: self.auth_context_class,
_ac: self._ac,
})
}
/// discover issuer details
pub async fn discover(
self,
issuer: impl Into<Uri>,
) -> Result<Builder<AC, ABU, ClientCredentials, OpenidconnectClient<AC>, HttpClient>, Error>
{
let issuer_url = IssuerUrl::new(issuer.into().to_string())?;
let http_client = self.http_client.0.clone();
let provider_metadata = ProviderMetadata::discover_async(issuer_url, &http_client);
Self::manual(self, provider_metadata.await?)
}
}
impl<AC: AdditionalClaims>
Builder<AC, ApplicationBaseUrl, ClientCredentials, OpenidconnectClient<AC>, HttpClient>
{
/// create oidc client
pub fn build(self) -> OidcClient<AC> {
OidcClient {
scopes: self.scopes,
oidc_request_parameters: self.oidc_request_parameters,
client_id: self.credentials.id,
client: self.client.0,
http_client: self.http_client.0,
application_base_url: self.application_base_url.0,
end_session_endpoint: self.end_session_endpoint,
auth_context_class: self.auth_context_class,
}
}
}

View file

@ -111,7 +111,6 @@ impl IntoResponse for ExtractorError {
impl IntoResponse for Error {
fn into_response(self) -> axum_core::response::Response {
dbg!(&self);
match self {
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
@ -120,7 +119,6 @@ impl IntoResponse for Error {
impl IntoResponse for MiddlewareError {
fn into_response(self) -> axum_core::response::Response {
dbg!(&self);
match self {
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}

View file

@ -112,8 +112,8 @@ impl AsRef<str> for OidcAccessToken {
#[derive(Clone)]
pub struct OidcRpInitiatedLogout {
pub(crate) end_session_endpoint: Uri,
pub(crate) id_token_hint: String,
pub(crate) client_id: String,
pub(crate) id_token_hint: Box<str>,
pub(crate) client_id: Box<str>,
pub(crate) post_logout_redirect_uri: Option<Uri>,
pub(crate) state: Option<String>,
}

View file

@ -3,7 +3,6 @@
#![deny(warnings)]
#![doc = include_str!("../README.md")]
use crate::error::Error;
use http::Uri;
use openidconnect::{
core::{
@ -13,12 +12,13 @@ use openidconnect::{
CoreResponseMode, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse,
CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, CoreTokenType,
},
AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet,
EndpointNotSet, EndpointSet, IdTokenFields, IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken,
StandardErrorResponse, StandardTokenResponse,
AccessToken, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet, EndpointNotSet, EndpointSet,
IdTokenFields, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse,
StandardTokenResponse,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
pub mod builder;
pub mod error;
mod extractor;
mod middleware;
@ -99,118 +99,14 @@ pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// OpenID Connect Client
#[derive(Clone)]
pub struct OidcClient<AC: AdditionalClaims> {
scopes: Vec<String>,
client_id: String,
scopes: Vec<Box<str>>,
oidc_request_parameters: Vec<Box<str>>,
client_id: Box<str>,
client: Client<AC>,
http_client: reqwest::Client,
application_base_url: Uri,
end_session_endpoint: Option<Uri>,
}
impl<AC: AdditionalClaims> OidcClient<AC> {
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
pub fn from_provider_metadata(
provider_metadata: ProviderMetadata,
application_base_url: Uri,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self, Error> {
let end_session_endpoint = provider_metadata
.additional_metadata()
.end_session_endpoint
.clone()
.map(Uri::from_maybe_shared)
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(Self {
scopes,
client,
client_id,
application_base_url,
end_session_endpoint,
http_client: reqwest::Client::default(),
})
}
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
pub fn from_provider_metadata_and_client(
provider_metadata: ProviderMetadata,
application_base_url: Uri,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
http_client: reqwest::Client,
) -> Result<Self, Error> {
let end_session_endpoint = provider_metadata
.additional_metadata()
.end_session_endpoint
.clone()
.map(Uri::from_maybe_shared)
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(Self {
scopes,
client,
client_id,
application_base_url,
end_session_endpoint,
http_client,
})
}
/// create a new [`OidcClient`] by fetching the required information from the
/// `/.well-known/openid-configuration` endpoint of the issuer.
pub async fn discover_new(
application_base_url: Uri,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self, Error> {
let client = reqwest::Client::default();
Self::discover_new_with_client(
application_base_url,
issuer,
client_id,
client_secret,
scopes,
client,
)
.await
}
/// create a new [`OidcClient`] by fetching the required information from the
/// `/.well-known/openid-configuration` endpoint of the issuer using the provided
/// `reqwest::Client`.
pub async fn discover_new_with_client(
application_base_url: Uri,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
client: reqwest::Client,
) -> Result<Self, Error> {
let provider_metadata =
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &client).await?;
Self::from_provider_metadata_and_client(
provider_metadata,
application_base_url,
client_id,
client_secret,
scopes,
client,
)
}
auth_context_class: Option<Box<str>>,
}
/// an empty struct to be used as the default type for the additional claims generic

View file

@ -16,14 +16,15 @@ use tower_sessions::Session;
use openidconnect::{
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, IdTokenVerifier,
Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken,
AccessToken, AccessTokenHash, AuthenticationContextClass, AuthorizationCode, CsrfToken,
IdTokenClaims, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge,
PkceCodeVerifier, RedirectUrl, RefreshToken,
RequestTokenError::ServerResponse,
Scope, TokenResponse,
};
use crate::{
error::{Error, MiddlewareError},
error::MiddlewareError,
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
OidcQuery, OidcSession, SESSION_KEY,
@ -126,8 +127,11 @@ where
.await
.map_err(MiddlewareError::from)?;
let handler_uri =
strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?;
let handler_uri = strip_oidc_from_path(
oidcclient.application_base_url.clone(),
&parts.uri,
&oidcclient.oidc_request_parameters,
)?;
oidcclient.client = oidcclient
.client
@ -192,6 +196,12 @@ where
auth = auth.add_scope(Scope::new(scope.to_string()));
}
if let Some(acr) = oidcclient.auth_context_class {
auth = auth.add_auth_context_value(AuthenticationContextClass::new(
acr.into(),
));
}
auth.set_pkce_challenge(pkce_challenge).url()
};
@ -225,24 +235,10 @@ impl<AC: AdditionalClaims> OidcAuthLayer<AC> {
pub fn new(client: OidcClient<AC>) -> Self {
Self { client }
}
pub async fn discover_client(
application_base_url: Uri,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self, Error> {
Ok(Self {
client: OidcClient::<AC>::discover_new(
application_base_url,
issuer,
client_id,
client_secret,
scopes,
)
.await?,
})
}
impl<AC: AdditionalClaims> From<OidcClient<AC>> for OidcAuthLayer<AC> {
fn from(value: OidcClient<AC>) -> Self {
Self::new(value)
}
}
@ -309,8 +305,11 @@ where
.await
.map_err(MiddlewareError::from)?;
let handler_uri =
strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?;
let handler_uri = strip_oidc_from_path(
oidcclient.application_base_url.clone(),
&parts.uri,
&oidcclient.oidc_request_parameters,
)?;
oidcclient.client = oidcclient
.client
@ -373,7 +372,11 @@ where
/// Helper function to remove the OpenID Connect authentication response query attributes from a
/// [`Uri`].
pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result<Uri, MiddlewareError> {
pub fn strip_oidc_from_path(
base_url: Uri,
uri: &Uri,
filter: &[Box<str>],
) -> Result<Uri, MiddlewareError> {
let mut base_url = base_url.into_parts();
base_url.path_and_query = uri
@ -381,20 +384,20 @@ pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result<Uri, MiddlewareE
.map(|path_and_query| {
let query = path_and_query
.query()
.and_then(|uri| {
.map(|uri| {
uri.split('&')
.filter(|x| {
!x.starts_with("code")
&& !x.starts_with("state")
&& !x.starts_with("session_state")
&& !x.starts_with("iss")
.filter(|x| filter.iter().all(|y| !x.starts_with(y.as_ref())))
.fold(String::default(), |mut acc, x| {
if !acc.is_empty() {
acc += "&";
} else {
acc += "?";
}
acc += x;
acc
})
.map(|x| x.to_string())
.reduce(|acc, x| acc + "&" + &x)
})
.map(|x| format!("?{x}"))
.unwrap_or_default();
PathAndQuery::from_maybe_shared(format!("{}{}", path_and_query.path(), query))
})
.transpose()?;
@ -418,7 +421,7 @@ fn insert_extensions<AC: AdditionalClaims>(
.as_ref()
.map(|end_session_endpoint| OidcRpInitiatedLogout {
end_session_endpoint: end_session_endpoint.clone(),
id_token_hint: authenticated_session.id_token.to_string(),
id_token_hint: authenticated_session.id_token.to_string().into(),
client_id: client.client_id.clone(),
post_logout_redirect_uri: None,
state: None,