mirror of
https://github.com/pfzetto/axum-oidc.git
synced 2024-12-03 16:27:14 +01:00
fix: use custom reqwest::Client
in middleware
previously the middlewares would use the default `reqwest::Client` even if the `OidcClient` is created with a custom client. Now the middleware uses the `reqwest::Client` used during creation of `OidcClient`.
This commit is contained in:
parent
202b61fa83
commit
9dd85a7703
2 changed files with 95 additions and 12 deletions
36
src/lib.rs
36
src/lib.rs
|
@ -98,6 +98,7 @@ pub struct OidcClient<AC: AdditionalClaims> {
|
||||||
scopes: Vec<String>,
|
scopes: Vec<String>,
|
||||||
client_id: String,
|
client_id: String,
|
||||||
client: Client<AC>,
|
client: Client<AC>,
|
||||||
|
http_client: reqwest::Client,
|
||||||
application_base_url: Uri,
|
application_base_url: Uri,
|
||||||
end_session_endpoint: Option<Uri>,
|
end_session_endpoint: Option<Uri>,
|
||||||
}
|
}
|
||||||
|
@ -129,6 +130,37 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
|
||||||
client_id,
|
client_id,
|
||||||
application_base_url,
|
application_base_url,
|
||||||
end_session_endpoint,
|
end_session_endpoint,
|
||||||
|
http_client: reqwest::Client::default(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
/// create a new [`OidcClient`] from an existing [`ProviderMetadata`].
|
||||||
|
pub fn from_provider_metadata_and_client(
|
||||||
|
provider_metadata: ProviderMetadata,
|
||||||
|
application_base_url: Uri,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: Option<String>,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
http_client: reqwest::Client,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
|
let end_session_endpoint = provider_metadata
|
||||||
|
.additional_metadata()
|
||||||
|
.end_session_endpoint
|
||||||
|
.clone()
|
||||||
|
.map(Uri::from_maybe_shared)
|
||||||
|
.transpose()
|
||||||
|
.map_err(Error::InvalidEndSessionEndpoint)?;
|
||||||
|
let client = Client::from_provider_metadata(
|
||||||
|
provider_metadata,
|
||||||
|
ClientId::new(client_id.clone()),
|
||||||
|
client_secret.map(ClientSecret::new),
|
||||||
|
);
|
||||||
|
Ok(Self {
|
||||||
|
scopes,
|
||||||
|
client,
|
||||||
|
client_id,
|
||||||
|
application_base_url,
|
||||||
|
end_session_endpoint,
|
||||||
|
http_client,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -162,6 +194,7 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
|
||||||
client_id: String,
|
client_id: String,
|
||||||
client_secret: Option<String>,
|
client_secret: Option<String>,
|
||||||
scopes: Vec<String>,
|
scopes: Vec<String>,
|
||||||
|
//TODO remove borrow with next breaking version
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
// modified version of `openidconnect::reqwest::async_client::async_http_client`.
|
// modified version of `openidconnect::reqwest::async_client::async_http_client`.
|
||||||
|
@ -196,12 +229,13 @@ impl<AC: AdditionalClaims> OidcClient<AC> {
|
||||||
|
|
||||||
let provider_metadata =
|
let provider_metadata =
|
||||||
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
|
ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
|
||||||
Self::from_provider_metadata(
|
Self::from_provider_metadata_and_client(
|
||||||
provider_metadata,
|
provider_metadata,
|
||||||
application_base_url,
|
application_base_url,
|
||||||
client_id,
|
client_id,
|
||||||
client_secret,
|
client_secret,
|
||||||
scopes,
|
scopes,
|
||||||
|
client.clone(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
|
pin::Pin,
|
||||||
task::{Context, Poll},
|
task::{Context, Poll},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -8,7 +9,7 @@ use axum::{
|
||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_core::{extract::FromRequestParts, response::Response};
|
use axum_core::{extract::FromRequestParts, response::Response};
|
||||||
use futures_util::future::BoxFuture;
|
use futures_util::{future::BoxFuture, Future};
|
||||||
use http::{request::Parts, uri::PathAndQuery, Request, Uri};
|
use http::{request::Parts, uri::PathAndQuery, Request, Uri};
|
||||||
use tower_layer::Layer;
|
use tower_layer::Layer;
|
||||||
use tower_service::Service;
|
use tower_service::Service;
|
||||||
|
@ -16,9 +17,9 @@ use tower_sessions::Session;
|
||||||
|
|
||||||
use openidconnect::{
|
use openidconnect::{
|
||||||
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim},
|
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim},
|
||||||
reqwest::async_http_client,
|
AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, HttpRequest, HttpResponse,
|
||||||
AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce,
|
IdTokenClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
|
||||||
OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken,
|
RefreshToken,
|
||||||
RequestTokenError::ServerResponse,
|
RequestTokenError::ServerResponse,
|
||||||
Scope, TokenResponse,
|
Scope, TokenResponse,
|
||||||
};
|
};
|
||||||
|
@ -149,7 +150,7 @@ where
|
||||||
.set_pkce_verifier(PkceCodeVerifier::new(
|
.set_pkce_verifier(PkceCodeVerifier::new(
|
||||||
login_session.pkce_verifier.secret().to_string(),
|
login_session.pkce_verifier.secret().to_string(),
|
||||||
))
|
))
|
||||||
.request_async(async_http_client)
|
.request_async(async_http_client(&oidcclient.http_client))
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Extract the ID token claims after verifying its authenticity and nonce.
|
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||||
|
@ -409,16 +410,17 @@ fn insert_extensions<AC: AdditionalClaims>(
|
||||||
parts.extensions.insert(OidcAccessToken(
|
parts.extensions.insert(OidcAccessToken(
|
||||||
authenticated_session.access_token.secret().to_string(),
|
authenticated_session.access_token.secret().to_string(),
|
||||||
));
|
));
|
||||||
let rp_initiated_logout = client.end_session_endpoint.as_ref().map(|end_session_endpoint|
|
let rp_initiated_logout = client
|
||||||
OidcRpInitiatedLogout {
|
.end_session_endpoint
|
||||||
|
.as_ref()
|
||||||
|
.map(|end_session_endpoint| OidcRpInitiatedLogout {
|
||||||
end_session_endpoint: end_session_endpoint.clone(),
|
end_session_endpoint: end_session_endpoint.clone(),
|
||||||
id_token_hint: authenticated_session.id_token.to_string(),
|
id_token_hint: authenticated_session.id_token.to_string(),
|
||||||
client_id: client.client_id.clone(),
|
client_id: client.client_id.clone(),
|
||||||
post_logout_redirect_uri: None,
|
post_logout_redirect_uri: None,
|
||||||
state: None,
|
state: None,
|
||||||
}
|
});
|
||||||
);
|
parts.extensions.insert(rp_initiated_logout);
|
||||||
parts.extensions.insert(rp_initiated_logout);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Verify the access token hash to ensure that the access token hasn't been substituted for
|
/// Verify the access token hash to ensure that the access token hasn't been substituted for
|
||||||
|
@ -460,7 +462,10 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
refresh_request = refresh_request.add_scope(Scope::new(scope.to_string()));
|
refresh_request = refresh_request.add_scope(Scope::new(scope.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
match refresh_request.request_async(async_http_client).await {
|
match refresh_request
|
||||||
|
.request_async(async_http_client(&client.http_client))
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(token_response) => {
|
Ok(token_response) => {
|
||||||
// Extract the ID token claims after verifying its authenticity and nonce.
|
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||||
let id_token = token_response
|
let id_token = token_response
|
||||||
|
@ -489,3 +494,47 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
Err(err) => Err(err.into()),
|
Err(err) => Err(err.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `openidconnect::reqwest::async_http_client` that uses a custom `reqwest::client`
|
||||||
|
fn async_http_client<'a>(
|
||||||
|
client: &'a reqwest::Client,
|
||||||
|
) -> impl FnOnce(
|
||||||
|
HttpRequest,
|
||||||
|
) -> Pin<
|
||||||
|
Box<
|
||||||
|
dyn Future<Output = Result<HttpResponse, openidconnect::reqwest::Error<reqwest::Error>>>
|
||||||
|
+ Send
|
||||||
|
+ 'a,
|
||||||
|
>,
|
||||||
|
> {
|
||||||
|
move |request: HttpRequest| {
|
||||||
|
Box::pin(async move {
|
||||||
|
let mut request_builder = client
|
||||||
|
.request(request.method, request.url.as_str())
|
||||||
|
.body(request.body);
|
||||||
|
for (name, value) in &request.headers {
|
||||||
|
request_builder = request_builder.header(name.as_str(), value.as_bytes());
|
||||||
|
}
|
||||||
|
let request = request_builder
|
||||||
|
.build()
|
||||||
|
.map_err(openidconnect::reqwest::Error::Reqwest)?;
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.execute(request)
|
||||||
|
.await
|
||||||
|
.map_err(openidconnect::reqwest::Error::Reqwest)?;
|
||||||
|
|
||||||
|
let status_code = response.status();
|
||||||
|
let headers = response.headers().to_owned();
|
||||||
|
let chunks = response
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(openidconnect::reqwest::Error::Reqwest)?;
|
||||||
|
Ok(HttpResponse {
|
||||||
|
status_code,
|
||||||
|
headers,
|
||||||
|
body: chunks.to_vec(),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue