From 5952cbff95489ba038ffff18cf1a0d1dd54e71f5 Mon Sep 17 00:00:00 2001 From: JuliDi <20155974+JuliDi@users.noreply.github.com> Date: Thu, 15 May 2025 18:36:44 +0200 Subject: [PATCH] add UserInfoClaims add allow additional audiences add tracing update basic example apply clippy lints --- Cargo.toml | 1 + examples/basic/Cargo.toml | 23 ++++++++---- examples/basic/src/lib.rs | 51 ++++++++++++++++++------- examples/basic/src/main.rs | 6 +++ src/error.rs | 25 ++++++++----- src/extractor.rs | 56 ++++++++++++++++++++++++++- src/handler.rs | 34 ++++++++++++++--- src/lib.rs | 5 ++- src/middleware.rs | 77 ++++++++++++++++++++++++++++++++------ 9 files changed, 226 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 52e1353..be78178 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ serde = "1.0" futures-util = "0.3" reqwest = { version = "0.12", default-features = false } urlencoding = "2.1" +tracing = "0.1.41" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index bf2562f..8322c85 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -1,25 +1,32 @@ [package] +edition = "2021" name = "basic" version = "0.1.0" -edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -tokio = { version = "1.43", features = ["net", "macros", "rt-multi-thread"] } -axum = { version = "0.8", features = [ "macros" ]} +axum = { version = "0.8", features = ["macros"] } axum-oidc = { path = "./../.." } +dotenvy = "0.15" +tokio = { version = "1.43", features = ["macros", "net", "rt-multi-thread"] } tower = "0.5" tower-sessions = "0.14" -dotenvy = "0.15" +openidconnect = "4.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.140" +tracing-subscriber = "0.3.19" +tracing = "0.1.41" [dev-dependencies] +env_logger = "0.11" +headless_chrome = "1.0" +log = "0.4" +reqwest = { version = "0.12", features = [ + "rustls-tls", +], default-features = false } testcontainers = "0.23" tokio = { version = "1.43", features = ["rt-multi-thread"] } -reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } -env_logger = "0.11" -log = "0.4" -headless_chrome = "1.0" #see https://github.com/rust-headless-chrome/rust-headless-chrome/issues/535 auto_generate_cdp = "=0.4.4" diff --git a/examples/basic/src/lib.rs b/examples/basic/src/lib.rs index 96593ee..0399551 100644 --- a/examples/basic/src/lib.rs +++ b/examples/basic/src/lib.rs @@ -6,9 +6,10 @@ use axum::{ Router, }; use axum_oidc::{ - error::MiddlewareError, handle_oidc_redirect, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, - OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, + error::MiddlewareError, handle_oidc_redirect, AdditionalClaims, Audience, Config, + OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, OidcUserClaims, }; +use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use tower::ServiceBuilder; use tower_sessions::{ @@ -16,6 +17,15 @@ use tower_sessions::{ Expiry, MemoryStore, SessionManagerLayer, }; +#[derive(Clone, Debug, Serialize, Deserialize)] +//struct MyAdditionalClaims(HashMap); +struct MyAdditionalClaims { + admin: Option, +} + +impl AdditionalClaims for MyAdditionalClaims {} +impl openidconnect::AdditionalClaims for MyAdditionalClaims {} + pub async fn run(issuer: String, client_id: String, client_secret: Option) { let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store) @@ -28,30 +38,42 @@ pub async fn run(issuer: String, client_id: String, client_secret: Option::new()); + .layer(OidcLoginLayer::::new()); - let mut oidc_client = OidcClient::::builder() + let mut oidc_client = OidcClient::::builder() .with_default_http_client() - .with_redirect_url(Uri::from_static("http://localhost:8080/oidc")) - .with_client_id(client_id); + .with_redirect_url(Uri::from_static("http://127.0.0.1:8080/oidc")) + .with_client_id(client_id) + .add_scope("profile") + .add_scope("email") + .add_scope("urn:zitadel:iam:org:project:id:zitadel:aud"); if let Some(client_secret) = client_secret { oidc_client = oidc_client.with_client_secret(client_secret); } let oidc_client = oidc_client.discover(issuer).await.unwrap().build(); + let config = Config { + other_audiences: vec![ + Audience::new("318246545105453932".to_string()), + Audience::new("318244871846527852".to_string()), + Audience::new("317981086246456313".to_string()), + ], + }; + let oidc_auth_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { dbg!(&e); e.into_response() })) - .layer(OidcAuthLayer::new(oidc_client)); + .layer(OidcAuthLayer::new(oidc_client, config.clone())); let app = Router::new() .route("/foo", get(authenticated)) .route("/logout", get(logout)) .layer(oidc_login_service) .route("/bar", get(maybe_authenticated)) - .route("/oidc", any(handle_oidc_redirect::)) + .route("/oidc", any(handle_oidc_redirect::)) + .with_state(config) .layer(oidc_auth_service) .layer(session_layer); @@ -61,24 +83,25 @@ pub async fn run(issuer: String, client_id: String, client_secret: Option) -> impl IntoResponse { +async fn authenticated(claims: OidcClaims) -> impl IntoResponse { format!("Hello {}", claims.subject().as_str()) } #[axum::debug_handler] async fn maybe_authenticated( - claims: Result, axum_oidc::error::ExtractorError>, + claims: Result, axum_oidc::error::ExtractorError>, ) -> impl IntoResponse { if let Ok(claims) = claims { + dbg!(&claims); format!( - "Hello {}! You are already logged in from another Handler.", - claims.subject().as_str() + "Hello {:#?}! You are already logged in from another Handler.", + claims.name().unwrap().get(None).unwrap().as_str() ) } else { - "Hello anon!".to_string() + "Hello unauthenticated user!".to_string() } } async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { - logout.with_post_logout_redirect(Uri::from_static("https://example.com")) + logout.with_post_logout_redirect(Uri::from_static("http://127.0.0.1:8080/bar")) } diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 0456d55..82d9645 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,6 +1,12 @@ use basic::run; +use tracing::Level; #[tokio::main] async fn main() { + tracing_subscriber::fmt() + .with_file(true) + .with_line_number(true) + .with_max_level(Level::TRACE) + .init(); dotenvy::dotenv().ok(); let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); diff --git a/src/error.rs b/src/error.rs index 1bd47d5..d55a990 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,6 +41,14 @@ pub enum MiddlewareError { #[error("claims verification: {0:?}")] ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), + #[error("user info retrieval: {0:?}")] + UserInfoRetrieval( + #[from] + openidconnect::UserInfoError< + openidconnect::HttpClientError, + >, + ), + #[error("url parsing: {0:?}")] UrlParsing(#[from] openidconnect::url::ParseError), @@ -74,7 +82,7 @@ pub enum MiddlewareError { #[derive(Debug, Error)] pub enum HandlerError { - #[error("the redirect handler got accessed without a valid session")] + #[error("redirect handler accessed without valid session, session cookie missing?")] RedirectedWithoutSession, #[error("csrf token invalid")] @@ -153,24 +161,21 @@ impl IntoResponse for ExtractorError { impl IntoResponse for Error { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } impl IntoResponse for MiddlewareError { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } impl IntoResponse for HandlerError { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } diff --git a/src/extractor.rs b/src/extractor.rs index 8735ee7..4385a86 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -7,12 +7,12 @@ use axum_core::{ response::IntoResponse, }; use http::{request::Parts, uri::PathAndQuery, Uri}; -use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; +use openidconnect::{core::CoreGenderClaim, IdTokenClaims, UserInfoClaims}; /// Extractor for the OpenID Connect Claims. /// /// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct OidcClaims(pub IdTokenClaims); impl FromRequestParts for OidcClaims @@ -213,3 +213,55 @@ impl IntoResponse for OidcRpInitiatedLogout { } } } + + +/// Extractor for the OpenID Connect User Info Claims. +/// +/// This Extractor will only return the User Info Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. +#[derive(Clone, Debug)] +pub struct OidcUserClaims(pub UserInfoClaims); + +impl FromRequestParts for OidcUserClaims +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} + +impl OptionalFromRequestParts for OidcUserClaims +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + +impl Deref for OidcUserClaims { + type Target = UserInfoClaims; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef> for OidcUserClaims +where + AC: AdditionalClaims, +{ + fn as_ref(&self) -> &UserInfoClaims { + &self.0 + } +} diff --git a/src/handler.rs b/src/handler.rs index c3bbd95..9038a79 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,4 +1,8 @@ -use axum::{extract::Query, response::Redirect, Extension}; +use axum::{ + extract::{Query, State}, + response::Redirect, + Extension, +}; use openidconnect::{ core::{CoreGenderClaim, CoreJsonWebKey}, AccessToken, AccessTokenHash, AuthorizationCode, IdTokenClaims, IdTokenVerifier, @@ -8,8 +12,8 @@ use serde::Deserialize; use tower_sessions::Session; use crate::{ - error::HandlerError, AdditionalClaims, AuthenticatedSession, IdToken, OidcClient, OidcSession, - SESSION_KEY, + error::HandlerError, AdditionalClaims, AuthenticatedSession, Config, IdToken, OidcClient, + OidcSession, SESSION_KEY, }; /// response data of the openid issuer after login @@ -21,11 +25,16 @@ pub struct OidcQuery { session_state: Option, } +#[tracing::instrument(skip(oidcclient), err)] pub async fn handle_oidc_redirect( session: Session, Extension(oidcclient): Extension>, + State(config): State, Query(query): Query, ) -> Result { + + tracing::debug!("start handling oidc redirect"); + let mut login_session: OidcSession = session .get(SESSION_KEY) .await? @@ -33,10 +42,12 @@ pub async fn handle_oidc_redirect( // the request has the request headers of the oidc redirect // parse the headers and exchange the code for a valid token + tracing::debug!("validating scrf token"); if login_session.csrf_token.secret() != &query.state { return Err(HandlerError::CsrfTokenInvalid); } + tracing::debug!("obtain token response"); let token_response = oidcclient .client .exchange_code(AuthorizationCode::new(query.code.to_string()))? @@ -47,19 +58,27 @@ pub async fn handle_oidc_redirect( .request_async(&oidcclient.http_client) .await?; + tracing::debug!("extract claims and verify it"); // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() .ok_or(HandlerError::IdTokenMissing)?; - let id_token_verifier = oidcclient.client.id_token_verifier(); + let id_token_verifier = oidcclient + .client + .id_token_verifier() + .set_other_audience_verifier_fn(|audience| config.other_audiences.contains(audience)); let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; + tracing::debug!("validate access token hash"); validate_access_token_hash( id_token, id_token_verifier, token_response.access_token(), claims, - )?; + ) + .inspect_err(|e| tracing::error!(?e, "Access token hash invalid"))?; + + tracing::debug!("Access token hash validated"); login_session.authenticated = Some(AuthenticatedSession { id_token: id_token.clone(), @@ -70,6 +89,10 @@ pub async fn handle_oidc_redirect( login_session.refresh_token = Some(refresh_token); } + tracing::debug!( + "Inserting session and redirecting to {}", + &login_session.redirect_url + ); let redirect_url = login_session.redirect_url.clone(); session.insert(SESSION_KEY, login_session).await?; @@ -79,6 +102,7 @@ pub async fn handle_oidc_redirect( /// Verify the access token hash to ensure that the access token hasn't been substituted for /// another user's. /// Returns `Ok` when access token is valid +#[tracing::instrument(skip_all, err)] fn validate_access_token_hash( id_token: &IdToken, id_token_verifier: IdTokenVerifier, diff --git a/src/lib.rs b/src/lib.rs index dc22366..17872ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,9 +24,10 @@ mod extractor; mod handler; mod middleware; -pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; +pub use extractor::{OidcAccessToken, OidcClaims, OidcUserClaims, OidcRpInitiatedLogout}; pub use handler::handle_oidc_redirect; -pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; +pub use middleware::{Config, OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; +pub use openidconnect::Audience; const SESSION_KEY: &str = "axum-oidc"; diff --git a/src/middleware.rs b/src/middleware.rs index 5eb14e6..fd24c00 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -13,19 +13,24 @@ use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, - AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims, + AccessToken, AccessTokenHash, Audience, AuthenticationContextClass, CsrfToken, IdTokenClaims, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken, RequestTokenError::ServerResponse, - Scope, TokenResponse, + Scope, TokenResponse, UserInfoClaims, }; use crate::{ error::MiddlewareError, - extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, + extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserClaims}, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, OidcSession, SESSION_KEY, }; +#[derive(Clone, Default, Debug)] +pub struct Config { + pub other_audiences: Vec, +} + /// Layer for the [`OidcLoginMiddleware`]. #[derive(Clone, Default)] pub struct OidcLoginLayer @@ -161,16 +166,17 @@ where AC: AdditionalClaims, { client: OidcClient, + config: Config, } impl OidcAuthLayer { - pub fn new(client: OidcClient) -> Self { - Self { client } + pub fn new(client: OidcClient, config: Config) -> Self { + Self { client, config } } } impl From> for OidcAuthLayer { fn from(value: OidcClient) -> Self { - Self::new(value) + Self::new(value, Config::default()) } } @@ -184,6 +190,7 @@ where OidcAuthMiddleware { inner, client: self.client.clone(), + config: self.config.clone(), } } } @@ -199,6 +206,7 @@ where { inner: I, client: OidcClient, + config: Config, } impl Service> for OidcAuthMiddleware @@ -223,7 +231,9 @@ where fn call(&mut self, request: Request) -> Self::Future { let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); + let other_audiences = self.config.other_audiences.clone(); let oidcclient = self.client.clone(); + Box::pin(async move { let (mut parts, body) = request.into_parts(); @@ -241,21 +251,43 @@ where let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { session .id_token - .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) + .claims( + &oidcclient + .client + .id_token_verifier() + .set_other_audience_verifier_fn(|audience| { + other_audiences.contains(audience) + }), + &login_session.nonce, + ) .ok() .cloned() .map(|claims| (session, claims)) }); if let Some((session, claims)) = id_token_claims { + let user_claims = + get_user_claims(&oidcclient, session.access_token.clone()).await?; // stored id token is valid and can be used - insert_extensions(&mut parts, claims.clone(), &oidcclient, session); + insert_extensions( + &mut parts, + claims.clone(), + user_claims, + &oidcclient, + session, + ); } else if let Some(refresh_token) = login_session.refresh_token.as_ref() { // session is expired but can be refreshed using the refresh_token - if let Some((claims, authenticated_session, refresh_token)) = + if let Some((claims, user_claims, authenticated_session, refresh_token)) = try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? { - insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); + insert_extensions( + &mut parts, + claims, + user_claims.clone(), + &oidcclient, + &authenticated_session, + ); login_session.authenticated = Some(authenticated_session); if let Some(refresh_token) = refresh_token { @@ -297,10 +329,12 @@ where fn insert_extensions( parts: &mut Parts, claims: IdTokenClaims, + user_claims: UserInfoClaims, client: &OidcClient, authenticated_session: &AuthenticatedSession, ) { parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcUserClaims(user_claims)); parts.extensions.insert(OidcAccessToken( authenticated_session.access_token.secret().to_string(), )); @@ -342,6 +376,19 @@ fn validate_access_token_hash( } } +async fn get_user_claims( + client: &OidcClient, + access_token: AccessToken, +) -> Result, MiddlewareError> { + client + .client + .user_info(access_token, None) + .map_err(MiddlewareError::Configuration)? + .request_async(&client.http_client) + .await + .map_err(|e| e.into()) +} + async fn try_refresh_token( client: &OidcClient, refresh_token: &RefreshToken, @@ -349,6 +396,7 @@ async fn try_refresh_token( ) -> Result< Option<( IdTokenClaims, + UserInfoClaims, AuthenticatedSession, Option, )>, @@ -366,7 +414,10 @@ async fn try_refresh_token( let id_token = token_response .id_token() .ok_or(MiddlewareError::IdTokenMissing)?; - let id_token_verifier = client.client.id_token_verifier(); + let id_token_verifier = client + .client + .id_token_verifier() + .require_audience_match(false); let claims = id_token.claims(&id_token_verifier, nonce)?; validate_access_token_hash( @@ -381,8 +432,12 @@ async fn try_refresh_token( access_token: token_response.access_token().clone(), }; + let user_claims = + get_user_claims(client, authenticated_session.access_token.clone()).await?; + Ok(Some(( claims.clone(), + user_claims, authenticated_session, token_response.refresh_token().cloned(), )))