diff --git a/examples/basic/src/lib.rs b/examples/basic/src/lib.rs index df77766..0e8ad2c 100644 --- a/examples/basic/src/lib.rs +++ b/examples/basic/src/lib.rs @@ -41,6 +41,7 @@ pub async fn run( client_id, client_secret, vec![], + None, ) .await .unwrap(), diff --git a/src/lib.rs b/src/lib.rs index ae32f57..156d8b1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -100,6 +100,7 @@ pub struct OidcClient { client: Client, application_base_url: Uri, end_session_endpoint: Option, + acr: Option, } impl OidcClient { @@ -110,6 +111,7 @@ impl OidcClient { client_id: String, client_secret: Option, scopes: Vec, + acr: Option, ) -> Result { let end_session_endpoint = provider_metadata .additional_metadata() @@ -129,6 +131,7 @@ impl OidcClient { client_id, application_base_url, end_session_endpoint, + acr, }) } @@ -140,6 +143,7 @@ impl OidcClient { client_id: String, client_secret: Option, scopes: Vec, + acr: Option, ) -> Result { let client = reqwest::Client::default(); Self::discover_new_with_client( @@ -149,6 +153,7 @@ impl OidcClient { client_secret, scopes, &client, + acr, ) .await } @@ -163,6 +168,7 @@ impl OidcClient { client_secret: Option, scopes: Vec, client: &reqwest::Client, + acr: Option, ) -> Result { // modified version of `openidconnect::reqwest::async_client::async_http_client`. let async_http_client = |request: HttpRequest| async move { @@ -202,6 +208,7 @@ impl OidcClient { client_id, client_secret, scopes, + acr, ) } } diff --git a/src/middleware.rs b/src/middleware.rs index 8f0432a..c0f0335 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -17,8 +17,9 @@ use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, reqwest::async_http_client, - AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, - OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, + AccessToken, AccessTokenHash, AuthenticationContextClass, AuthorizationCode, CsrfToken, + IdTokenClaims, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, + RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -187,6 +188,10 @@ where for scope in oidcclient.scopes.iter() { auth = auth.add_scope(Scope::new(scope.to_string())); } + if let Some(acr) = oidcclient.acr { + auth = + auth.add_auth_context_value(AuthenticationContextClass::new(acr)); + } auth.set_pkce_challenge(pkce_challenge).url() }; @@ -228,6 +233,7 @@ impl OidcAuthLayer { client_id: String, client_secret: Option, scopes: Vec, + acr: Option, ) -> Result { Ok(Self { client: OidcClient::::discover_new( @@ -236,6 +242,7 @@ impl OidcAuthLayer { client_id, client_secret, scopes, + acr, ) .await?, })