From 2800b88b82b51a3d42df6b5abf1c3ca3d827e694 Mon Sep 17 00:00:00 2001 From: MATILLAT Quentin Date: Sat, 25 Jan 2025 21:30:16 +0100 Subject: [PATCH 1/2] chore(deps): Update to openidconnect 0.4 Signed-off-by: MATILLAT Quentin --- Cargo.toml | 4 +- examples/basic/Cargo.toml | 2 +- src/error.rs | 16 +++++-- src/lib.rs | 70 +++++++++-------------------- src/middleware.rs | 94 +++++++++++++-------------------------- 5 files changed, 68 insertions(+), 118 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5cd18f0..52e1353 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,8 @@ tower-service = "0.3" tower-layer = "0.3" tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } http = "1.2" -openidconnect = "3.5" +openidconnect = "4.0" serde = "1.0" futures-util = "0.3" -reqwest = { version = "0.11", default-features = false } +reqwest = { version = "0.12", default-features = false } urlencoding = "2.1" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index fe9028c..c5dcf7f 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -17,7 +17,7 @@ dotenvy = "0.15" [dev-dependencies] testcontainers = "0.23" tokio = { version = "1.43", features = ["rt-multi-thread"] } -reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } +reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } env_logger = "0.11" log = "0.4" headless_chrome = "1.0" diff --git a/src/error.rs b/src/error.rs index 454dddc..48f0903 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,11 +16,13 @@ pub enum ExtractorError { #[error("could not build rp initiated logout uri")] FailedToCreateRpInitiatedLogoutUri, - } #[derive(Debug, Error)] pub enum MiddlewareError { + #[error("configuration: {0:?}")] + Configuration(#[from] openidconnect::ConfigurationError), + #[error("access token hash invalid")] AccessTokenHashInvalid, @@ -33,6 +35,9 @@ pub enum MiddlewareError { #[error("signing: {0:?}")] Signing(#[from] openidconnect::SigningError), + #[error("signature verification: {0:?}")] + Signature(#[from] openidconnect::SignatureVerificationError), + #[error("claims verification: {0:?}")] ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), @@ -49,7 +54,7 @@ pub enum MiddlewareError { RequestToken( #[from] openidconnect::RequestTokenError< - openidconnect::reqwest::Error, + openidconnect::HttpClientError, StandardErrorResponse, >, ), @@ -76,7 +81,12 @@ pub enum Error { InvalidEndSessionEndpoint(http::uri::InvalidUri), #[error("discovery: {0:?}")] - Discovery(#[from] openidconnect::DiscoveryError>), + Discovery( + #[from] + openidconnect::DiscoveryError< + openidconnect::HttpClientError, + >, + ), #[error("extractor: {0:?}")] Extractor(#[from] ExtractorError), diff --git a/src/lib.rs b/src/lib.rs index a6825b3..94ed56a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,14 +8,13 @@ use http::Uri; use openidconnect::{ core::{ CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, - CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, - CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, - CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreRevocableToken, - CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, - CoreTokenType, + CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, + CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, CoreJwsSigningAlgorithm, + CoreResponseMode, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, + CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, CoreTokenType, }, - AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, HttpRequest, - HttpResponse, IdTokenFields, IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, + AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, EndpointMaybeSet, + EndpointNotSet, EndpointSet, IdTokenFields, IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -41,7 +40,6 @@ type OidcTokenResponse = StandardTokenResponse< CoreGenderClaim, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, - CoreJsonWebKeyType, >, CoreTokenType, >; @@ -51,25 +49,34 @@ pub type IdToken = openidconnect::IdToken< CoreGenderClaim, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, - CoreJsonWebKeyType, >; -type Client = openidconnect::Client< +type Client< + AC, + HasAuthUrl = EndpointSet, + HasDeviceAuthUrl = EndpointNotSet, + HasIntrospectionUrl = EndpointNotSet, + HasRevocationUrl = EndpointNotSet, + HasTokenUrl = EndpointMaybeSet, + HasUserInfoUrl = EndpointMaybeSet, +> = openidconnect::Client< AC, CoreAuthDisplay, CoreGenderClaim, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, - CoreJsonWebKeyType, - CoreJsonWebKeyUse, CoreJsonWebKey, CoreAuthPrompt, StandardErrorResponse, OidcTokenResponse, - CoreTokenType, CoreTokenIntrospectionResponse, CoreRevocableToken, CoreRevocationErrorResponse, + HasAuthUrl, + HasDeviceAuthUrl, + HasIntrospectionUrl, + HasRevocationUrl, + HasTokenUrl, + HasUserInfoUrl, >; pub type ProviderMetadata = openidconnect::ProviderMetadata< @@ -81,9 +88,6 @@ pub type ProviderMetadata = openidconnect::ProviderMetadata< CoreGrantType, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, - CoreJwsSigningAlgorithm, - CoreJsonWebKeyType, - CoreJsonWebKeyUse, CoreJsonWebKey, CoreResponseMode, CoreResponseType, @@ -197,38 +201,8 @@ impl OidcClient { //TODO remove borrow with next breaking version client: &reqwest::Client, ) -> Result { - // modified version of `openidconnect::reqwest::async_client::async_http_client`. - let async_http_client = |request: HttpRequest| async move { - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); - } - let request = request_builder - .build() - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let response = client - .execute(request) - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let status_code = response.status(); - let headers = response.headers().to_owned(); - let chunks = response - .bytes() - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - Ok(HttpResponse { - status_code, - headers, - body: chunks.to_vec(), - }) - }; - let provider_metadata = - ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, client).await?; Self::from_provider_metadata_and_client( provider_metadata, application_base_url, diff --git a/src/middleware.rs b/src/middleware.rs index 3ae8437..5538108 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -1,6 +1,5 @@ use std::{ marker::PhantomData, - pin::Pin, task::{Context, Poll}, }; @@ -9,17 +8,16 @@ use axum::{ response::{IntoResponse, Redirect}, }; use axum_core::{extract::FromRequestParts, response::Response}; -use futures_util::{future::BoxFuture, Future}; +use futures_util::future::BoxFuture; use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ - core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, - AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, HttpRequest, HttpResponse, - IdTokenClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, - RefreshToken, + core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, + AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, IdTokenVerifier, + Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -145,22 +143,27 @@ where let token_response = oidcclient .client - .exchange_code(AuthorizationCode::new(query.code.to_string())) + .exchange_code(AuthorizationCode::new(query.code.to_string()))? // Set the PKCE code verifier. .set_pkce_verifier(PkceCodeVerifier::new( login_session.pkce_verifier.secret().to_string(), )) - .request_async(async_http_client(&oidcclient.http_client)) + .request_async(&oidcclient.http_client) .await?; // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() .ok_or(MiddlewareError::IdTokenMissing)?; - let claims = id_token - .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; + let id_token_verifier = oidcclient.client.id_token_verifier(); + let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; - validate_access_token_hash(id_token, token_response.access_token(), claims)?; + validate_access_token_hash( + id_token, + id_token_verifier, + token_response.access_token(), + claims, + )?; login_session.authenticated = Some(AuthenticatedSession { id_token: id_token.clone(), @@ -428,12 +431,16 @@ fn insert_extensions( /// Returns `Ok` when access token is valid fn validate_access_token_hash( id_token: &IdToken, + id_token_verifier: IdTokenVerifier, access_token: &AccessToken, claims: &IdTokenClaims, ) -> Result<(), MiddlewareError> { if let Some(expected_access_token_hash) = claims.access_token_hash() { - let actual_access_token_hash = - AccessTokenHash::from_token(access_token, &id_token.signing_alg()?)?; + let actual_access_token_hash = AccessTokenHash::from_token( + access_token, + id_token.signing_alg()?, + id_token.signing_key(&id_token_verifier)?, + )?; if actual_access_token_hash == *expected_access_token_hash { Ok(()) } else { @@ -456,24 +463,27 @@ async fn try_refresh_token( )>, MiddlewareError, > { - let mut refresh_request = client.client.exchange_refresh_token(refresh_token); + let mut refresh_request = client.client.exchange_refresh_token(refresh_token)?; for scope in client.scopes.iter() { refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); } - match refresh_request - .request_async(async_http_client(&client.http_client)) - .await - { + match refresh_request.request_async(&client.http_client).await { Ok(token_response) => { // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() .ok_or(MiddlewareError::IdTokenMissing)?; - let claims = id_token.claims(&client.client.id_token_verifier(), nonce)?; + let id_token_verifier = client.client.id_token_verifier(); + let claims = id_token.claims(&id_token_verifier, nonce)?; - validate_access_token_hash(id_token, token_response.access_token(), claims)?; + validate_access_token_hash( + id_token, + id_token_verifier, + token_response.access_token(), + claims, + )?; let authenticated_session = AuthenticatedSession { id_token: id_token.clone(), @@ -494,47 +504,3 @@ async fn try_refresh_token( Err(err) => Err(err.into()), } } - -/// `openidconnect::reqwest::async_http_client` that uses a custom `reqwest::client` -fn async_http_client<'a>( - client: &'a reqwest::Client, -) -> impl FnOnce( - HttpRequest, -) -> Pin< - Box< - dyn Future>> - + Send - + 'a, - >, -> { - move |request: HttpRequest| { - Box::pin(async move { - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); - } - let request = request_builder - .build() - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let response = client - .execute(request) - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - - let status_code = response.status(); - let headers = response.headers().to_owned(); - let chunks = response - .bytes() - .await - .map_err(openidconnect::reqwest::Error::Reqwest)?; - Ok(HttpResponse { - status_code, - headers, - body: chunks.to_vec(), - }) - }) - } -} From 10349c61b51cce0bad43f044299c63061614c5c8 Mon Sep 17 00:00:00 2001 From: MATILLAT Quentin Date: Wed, 29 Jan 2025 15:03:53 +0100 Subject: [PATCH 2/2] chore!: Remove ref from http_client in constructors Signed-off-by: MATILLAT Quentin --- src/lib.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 94ed56a..bf3c15e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -184,7 +184,7 @@ impl OidcClient { client_id, client_secret, scopes, - &client, + client, ) .await } @@ -198,18 +198,17 @@ impl OidcClient { client_id: String, client_secret: Option, scopes: Vec, - //TODO remove borrow with next breaking version - client: &reqwest::Client, + client: reqwest::Client, ) -> Result { let provider_metadata = - ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, client).await?; + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, &client).await?; Self::from_provider_metadata_and_client( provider_metadata, application_base_url, client_id, client_secret, scopes, - client.clone(), + client, ) } }