fix: use custom reqwest::Client in middleware

previously the middlewares would use the default `reqwest::Client` even
if the `OidcClient` is created with a custom client.
Now the middleware uses the `reqwest::Client` used during creation of
`OidcClient`.
This commit is contained in:
Paul Zinselmeyer 2024-09-06 20:53:12 +02:00
parent 202b61fa83
commit 9dd85a7703
Signed by: pfzetto
GPG key ID: B471A1AF06C895FD
2 changed files with 95 additions and 12 deletions

View file

@ -98,6 +98,7 @@ pub struct OidcClient<AC: AdditionalClaims> {
scopes: Vec<String>, scopes: Vec<String>,
client_id: String, client_id: String,
client: Client<AC>, client: Client<AC>,
http_client: reqwest::Client,
application_base_url: Uri, application_base_url: Uri,
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
} }
@ -129,6 +130,37 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
client_id, client_id,
application_base_url, application_base_url,
end_session_endpoint, end_session_endpoint,
http_client: reqwest::Client::default(),
})
}
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
pub fn from_provider_metadata_and_client(
provider_metadata: ProviderMetadata,
application_base_url: Uri,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
http_client: reqwest::Client,
) -> Result<Self, Error> {
let end_session_endpoint = provider_metadata
.additional_metadata()
.end_session_endpoint
.clone()
.map(Uri::from_maybe_shared)
.transpose()
.map_err(Error::InvalidEndSessionEndpoint)?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
);
Ok(Self {
scopes,
client,
client_id,
application_base_url,
end_session_endpoint,
http_client,
}) })
} }
@ -162,6 +194,7 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
client_id: String, client_id: String,
client_secret: Option<String>, client_secret: Option<String>,
scopes: Vec<String>, scopes: Vec<String>,
//TODO remove borrow with next breaking version
client: &reqwest::Client, client: &reqwest::Client,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
// modified version of `openidconnect::reqwest::async_client::async_http_client`. // modified version of `openidconnect::reqwest::async_client::async_http_client`.
@ -196,12 +229,13 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
let provider_metadata = let provider_metadata =
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
Self::from_provider_metadata( Self::from_provider_metadata_and_client(
provider_metadata, provider_metadata,
application_base_url, application_base_url,
client_id, client_id,
client_secret, client_secret,
scopes, scopes,
client.clone(),
) )
} }
} }

View file

@ -1,5 +1,6 @@
use std::{ use std::{
marker::PhantomData, marker::PhantomData,
pin::Pin,
task::{Context, Poll}, task::{Context, Poll},
}; };
@ -8,7 +9,7 @@ use axum::{
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
}; };
use axum_core::{extract::FromRequestParts, response::Response}; use axum_core::{extract::FromRequestParts, response::Response};
use futures_util::future::BoxFuture; use futures_util::{future::BoxFuture, Future};
use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use http::{request::Parts, uri::PathAndQuery, Request, Uri};
use tower_layer::Layer; use tower_layer::Layer;
use tower_service::Service; use tower_service::Service;
@ -16,9 +17,9 @@ use tower_sessions::Session;
use openidconnect::{ use openidconnect::{
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim},
reqwest::async_http_client, AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, HttpRequest, HttpResponse,
AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, IdTokenClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RefreshToken,
RequestTokenError::ServerResponse, RequestTokenError::ServerResponse,
Scope, TokenResponse, Scope, TokenResponse,
}; };
@ -149,7 +150,7 @@ where
.set_pkce_verifier(PkceCodeVerifier::new( .set_pkce_verifier(PkceCodeVerifier::new(
login_session.pkce_verifier.secret().to_string(), login_session.pkce_verifier.secret().to_string(),
)) ))
.request_async(async_http_client) .request_async(async_http_client(&oidcclient.http_client))
.await?; .await?;
// Extract the ID token claims after verifying its authenticity and nonce. // Extract the ID token claims after verifying its authenticity and nonce.
@ -409,15 +410,16 @@ fn insert_extensions<AC: AdditionalClaims>(
parts.extensions.insert(OidcAccessToken( parts.extensions.insert(OidcAccessToken(
authenticated_session.access_token.secret().to_string(), authenticated_session.access_token.secret().to_string(),
)); ));
let rp_initiated_logout = client.end_session_endpoint.as_ref().map(|end_session_endpoint| let rp_initiated_logout = client
OidcRpInitiatedLogout { .end_session_endpoint
.as_ref()
.map(|end_session_endpoint| OidcRpInitiatedLogout {
end_session_endpoint: end_session_endpoint.clone(), end_session_endpoint: end_session_endpoint.clone(),
id_token_hint: authenticated_session.id_token.to_string(), id_token_hint: authenticated_session.id_token.to_string(),
client_id: client.client_id.clone(), client_id: client.client_id.clone(),
post_logout_redirect_uri: None, post_logout_redirect_uri: None,
state: None, state: None,
} });
);
parts.extensions.insert(rp_initiated_logout); parts.extensions.insert(rp_initiated_logout);
} }
@ -460,7 +462,10 @@ async fn try_refresh_token<AC: AdditionalClaims>(
refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); refresh_request = refresh_request.add_scope(Scope::new(scope.to_string()));
} }
match refresh_request.request_async(async_http_client).await { match refresh_request
.request_async(async_http_client(&client.http_client))
.await
{
Ok(token_response) => { Ok(token_response) => {
// Extract the ID token claims after verifying its authenticity and nonce. // Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response let id_token = token_response
@ -489,3 +494,47 @@ async fn try_refresh_token<AC: AdditionalClaims>(
Err(err) => Err(err.into()), Err(err) => Err(err.into()),
} }
} }
/// `openidconnect::reqwest::async_http_client` that uses a custom `reqwest::client`
fn async_http_client<'a>(
client: &'a reqwest::Client,
) -> impl FnOnce(
HttpRequest,
) -> Pin<
Box<
dyn Future<Output = Result<HttpResponse, openidconnect::reqwest::Error<reqwest::Error>>>
+ Send
+ 'a,
>,
> {
move |request: HttpRequest| {
Box::pin(async move {
let mut request_builder = client
.request(request.method, request.url.as_str())
.body(request.body);
for (name, value) in &request.headers {
request_builder = request_builder.header(name.as_str(), value.as_bytes());
}
let request = request_builder
.build()
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let response = client
.execute(request)
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
let status_code = response.status();
let headers = response.headers().to_owned();
let chunks = response
.bytes()
.await
.map_err(openidconnect::reqwest::Error::Reqwest)?;
Ok(HttpResponse {
status_code,
headers,
body: chunks.to_vec(),
})
})
}
}