diff --git a/src/lib.rs b/src/lib.rs index ae32f57..a6825b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,6 +98,7 @@ pub struct OidcClient { scopes: Vec, client_id: String, client: Client, + http_client: reqwest::Client, application_base_url: Uri, end_session_endpoint: Option, } @@ -129,6 +130,37 @@ impl OidcClient { client_id, application_base_url, end_session_endpoint, + http_client: reqwest::Client::default(), + }) + } + /// create a new [`OidcClient`] from an existing [`ProviderMetadata`]. + pub fn from_provider_metadata_and_client( + provider_metadata: ProviderMetadata, + application_base_url: Uri, + client_id: String, + client_secret: Option, + scopes: Vec, + http_client: reqwest::Client, + ) -> Result { + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; + let client = Client::from_provider_metadata( + provider_metadata, + ClientId::new(client_id.clone()), + client_secret.map(ClientSecret::new), + ); + Ok(Self { + scopes, + client, + client_id, + application_base_url, + end_session_endpoint, + http_client, }) } @@ -162,6 +194,7 @@ impl OidcClient { client_id: String, client_secret: Option, scopes: Vec, + //TODO remove borrow with next breaking version client: &reqwest::Client, ) -> Result { // modified version of `openidconnect::reqwest::async_client::async_http_client`. @@ -196,12 +229,13 @@ impl OidcClient { let provider_metadata = ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; - Self::from_provider_metadata( + Self::from_provider_metadata_and_client( provider_metadata, application_base_url, client_id, client_secret, scopes, + client.clone(), ) } } diff --git a/src/middleware.rs b/src/middleware.rs index 8f0432a..3ae8437 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,5 +1,6 @@ use std::{ marker::PhantomData, + pin::Pin, task::{Context, Poll}, }; @@ -8,7 +9,7 @@ use axum::{ response::{IntoResponse, Redirect}, }; use axum_core::{extract::FromRequestParts, response::Response}; -use futures_util::future::BoxFuture; +use futures_util::{future::BoxFuture, Future}; use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; @@ -16,9 +17,9 @@ use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, - reqwest::async_http_client, - AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, - OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, + AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, HttpRequest, HttpResponse, + IdTokenClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, + RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -149,7 +150,7 @@ where .set_pkce_verifier(PkceCodeVerifier::new( login_session.pkce_verifier.secret().to_string(), )) - .request_async(async_http_client) + .request_async(async_http_client(&oidcclient.http_client)) .await?; // Extract the ID token claims after verifying its authenticity and nonce. @@ -409,16 +410,17 @@ fn insert_extensions( parts.extensions.insert(OidcAccessToken( authenticated_session.access_token.secret().to_string(), )); - let rp_initiated_logout = client.end_session_endpoint.as_ref().map(|end_session_endpoint| -OidcRpInitiatedLogout { + let rp_initiated_logout = client + .end_session_endpoint + .as_ref() + .map(|end_session_endpoint| OidcRpInitiatedLogout { end_session_endpoint: end_session_endpoint.clone(), id_token_hint: authenticated_session.id_token.to_string(), client_id: client.client_id.clone(), post_logout_redirect_uri: None, state: None, - } - ); - parts.extensions.insert(rp_initiated_logout); + }); + parts.extensions.insert(rp_initiated_logout); } /// Verify the access token hash to ensure that the access token hasn't been substituted for @@ -460,7 +462,10 @@ async fn try_refresh_token( refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); } - match refresh_request.request_async(async_http_client).await { + match refresh_request + .request_async(async_http_client(&client.http_client)) + .await + { Ok(token_response) => { // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response @@ -489,3 +494,47 @@ async fn try_refresh_token( Err(err) => Err(err.into()), } } + +/// `openidconnect::reqwest::async_http_client` that uses a custom `reqwest::client` +fn async_http_client<'a>( + client: &'a reqwest::Client, +) -> impl FnOnce( + HttpRequest, +) -> Pin< + Box< + dyn Future>> + + Send + + 'a, + >, +> { + move |request: HttpRequest| { + Box::pin(async move { + let mut request_builder = client + .request(request.method, request.url.as_str()) + .body(request.body); + for (name, value) in &request.headers { + request_builder = request_builder.header(name.as_str(), value.as_bytes()); + } + let request = request_builder + .build() + .map_err(openidconnect::reqwest::Error::Reqwest)?; + + let response = client + .execute(request) + .await + .map_err(openidconnect::reqwest::Error::Reqwest)?; + + let status_code = response.status(); + let headers = response.headers().to_owned(); + let chunks = response + .bytes() + .await + .map_err(openidconnect::reqwest::Error::Reqwest)?; + Ok(HttpResponse { + status_code, + headers, + body: chunks.to_vec(), + }) + }) + } +}