diff --git a/Cargo.toml b/Cargo.toml index 84d25d3..ba70915 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,24 +3,30 @@ name = "axum-oidc" description = "A wrapper for the openidconnect crate for axum" version = "0.6.0" edition = "2021" -authors = [ "Paul Z " ] +authors = ["Paul Z "] readme = "README.md" repository = "https://github.com/pfz4/axum-oidc" license = "MPL-2.0" -keywords = [ "axum", "oidc", "openidconnect", "authentication" ] +keywords = ["axum", "oidc", "openidconnect", "authentication"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] thiserror = "2.0" axum-core = "0.5" -axum = { version = "0.8", default-features = false, features = [ "query", "original-uri" ] } +axum = { version = "0.8", default-features = false, features = [ + "query", + "original-uri", +] } tower-service = "0.3" tower-layer = "0.3" -tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } -http = "1.2" +tower-sessions = { version = "0.14", default-features = false, features = [ + "axum-core", +] } +http = "1.3.1" openidconnect = "4.0" serde = "1.0" futures-util = "0.3" reqwest = { version = "0.12", default-features = false } urlencoding = "2.1" +tracing = "0.1.41" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index 88426a4..86467c3 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -1,12 +1,16 @@ [package] +edition = "2024" name = "basic" version = "0.1.0" -edition = "2021" [dependencies] -tokio = { version = "1.43", features = ["net", "macros", "rt-multi-thread"] } -axum = { version = "0.8", features = [ "macros" ]} +axum = { version = "0.8", features = ["macros"] } axum-oidc = { path = "./../.." } +dotenvy = "0.15" +openidconnect = "4.0.1" +tokio = { version = "1.48.0", features = ["macros", "net", "rt-multi-thread"] } tower = "0.5" tower-sessions = "0.14" -dotenvy = "0.15" +tracing-subscriber = "0.3.20" +tracing = "0.1.41" +serde = "1.0.228" diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index c7f6831..8c76841 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -6,8 +6,9 @@ use axum::{ Router, }; use axum_oidc::{ - error::MiddlewareError, handle_oidc_redirect, ClientId, ClientSecret, EmptyAdditionalClaims, - OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, + error::MiddlewareError, handle_oidc_redirect, Audience, ClientId, ClientSecret, + EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, + OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -15,9 +16,15 @@ use tower_sessions::{ cookie::{time::Duration, SameSite}, Expiry, MemoryStore, SessionManagerLayer, }; +use tracing::Level; #[tokio::main] -pub async fn main() { +async fn main() { + tracing_subscriber::fmt() + .with_file(true) + .with_line_number(true) + .with_max_level(Level::INFO) + .init(); dotenvy::dotenv().ok(); let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); @@ -39,7 +46,12 @@ pub async fn main() { let mut oidc_client = OidcClient::::builder() .with_default_http_client() .with_redirect_url(Uri::from_static("http://localhost:8080/oidc")) - .with_client_id(ClientId::new(client_id)); + .with_client_id(ClientId::new(client_id)) + .add_scope("profile") + .add_scope("email") + // Optional: add untrusted audiences. If the `aud` claim contains any of these audiences, the token is rejected. + .add_untrusted_audience(Audience::new("123456789".to_string())); + if let Some(client_secret) = client_secret { oidc_client = oidc_client.with_client_secret(ClientSecret::new(client_secret)); } @@ -61,6 +73,9 @@ pub async fn main() { .layer(oidc_auth_service) .layer(session_layer); + tracing::info!("Running on http://localhost:8080"); + tracing::info!("Visit http://localhost:8080/bar or http://localhost:8080/foo"); + let listener = TcpListener::bind("[::]:8080").await.unwrap(); axum::serve(listener, app.into_make_service()) .await diff --git a/src/builder.rs b/src/builder.rs index 6df0d77..be6c733 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use http::Uri; -use openidconnect::{ClientId, ClientSecret, IssuerUrl}; +use openidconnect::{Audience, ClientId, ClientSecret, IssuerUrl}; use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata}; @@ -23,6 +23,7 @@ pub struct Builder, scopes: Vec>, auth_context_class: Option>, + untrusted_audiences: Vec, _ac: PhantomData, } @@ -42,6 +43,7 @@ impl Builder { end_session_endpoint: None, scopes: vec![Box::from("openid")], auth_context_class: None, + untrusted_audiences: Vec::new(), _ac: PhantomData, } } @@ -60,6 +62,7 @@ impl Builder>>) -> Self { self.scopes = scopes.map(|x| x.into()).collect::>(); @@ -71,6 +74,18 @@ impl Builder Self { + self.untrusted_audiences.push(audience); + self + } + + /// replace untrusted audiences + pub fn with_untrusted_audiences(mut self, untrusted_audiences: Vec) -> Self { + self.untrusted_audiences = untrusted_audiences; + self + } } impl Builder { @@ -90,6 +105,7 @@ impl Builder Builder Builder Builder Builder http_client: self.http_client.0, end_session_endpoint: self.end_session_endpoint, auth_context_class: self.auth_context_class, + untrusted_audiences: self.untrusted_audiences, } } } diff --git a/src/error.rs b/src/error.rs index 537d41a..071e911 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,6 +41,14 @@ pub enum MiddlewareError { #[error("claims verification: {0:?}")] ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), + #[error("user info retrieval: {0:?}")] + UserInfoRetrieval( + #[from] + openidconnect::UserInfoError< + openidconnect::HttpClientError, + >, + ), + #[error("url parsing: {0:?}")] UrlParsing(#[from] openidconnect::url::ParseError), @@ -77,7 +85,7 @@ pub enum MiddlewareError { #[derive(Debug, Error)] pub enum HandlerError { - #[error("the redirect handler got accessed without a valid session")] + #[error("redirect handler accessed without valid session, session cookie missing?")] RedirectedWithoutSession, #[error("csrf token invalid")] @@ -156,24 +164,21 @@ impl IntoResponse for ExtractorError { impl IntoResponse for Error { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } impl IntoResponse for MiddlewareError { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } impl IntoResponse for HandlerError { fn into_response(self) -> axum_core::response::Response { - match self { - _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), - } + tracing::error!(error = self.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() } } diff --git a/src/extractor.rs b/src/extractor.rs index aeb2ef2..dbb6482 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -7,12 +7,12 @@ use axum_core::{ response::IntoResponse, }; use http::{request::Parts, uri::PathAndQuery, Uri}; -use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims}; +use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims, UserInfoClaims}; /// Extractor for the OpenID Connect Claims. /// /// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct OidcClaims(pub IdTokenClaims); impl FromRequestParts for OidcClaims @@ -213,3 +213,54 @@ impl IntoResponse for OidcRpInitiatedLogout { } } } + +/// Extractor for the OpenID Connect User Info Claims. +/// +/// This Extractor will only return the User Info Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. +#[derive(Clone, Debug)] +pub struct OidcUserInfo(pub UserInfoClaims); + +impl FromRequestParts for OidcUserInfo +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} + +impl OptionalFromRequestParts for OidcUserInfo +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + +impl Deref for OidcUserInfo { + type Target = UserInfoClaims; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AsRef> for OidcUserInfo +where + AC: AdditionalClaims, +{ + fn as_ref(&self) -> &UserInfoClaims { + &self.0 + } +} diff --git a/src/handler.rs b/src/handler.rs index c3bbd95..68f4793 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -21,11 +21,14 @@ pub struct OidcQuery { session_state: Option, } +#[tracing::instrument(skip(oidcclient), err)] pub async fn handle_oidc_redirect( session: Session, Extension(oidcclient): Extension>, Query(query): Query, ) -> Result { + tracing::debug!("start handling oidc redirect"); + let mut login_session: OidcSession = session .get(SESSION_KEY) .await? @@ -33,10 +36,12 @@ pub async fn handle_oidc_redirect( // the request has the request headers of the oidc redirect // parse the headers and exchange the code for a valid token + tracing::debug!("validating scrf token"); if login_session.csrf_token.secret() != &query.state { return Err(HandlerError::CsrfTokenInvalid); } + tracing::debug!("obtain token response"); let token_response = oidcclient .client .exchange_code(AuthorizationCode::new(query.code.to_string()))? @@ -47,19 +52,29 @@ pub async fn handle_oidc_redirect( .request_async(&oidcclient.http_client) .await?; + tracing::debug!("extract claims and verify it"); // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() .ok_or(HandlerError::IdTokenMissing)?; - let id_token_verifier = oidcclient.client.id_token_verifier(); + let id_token_verifier = oidcclient + .client + .id_token_verifier() + .set_other_audience_verifier_fn(|audience| + // Return false (reject) if audience is in list of untrusted audiences + !oidcclient.untrusted_audiences.contains(audience)); let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; + tracing::debug!("validate access token hash"); validate_access_token_hash( id_token, id_token_verifier, token_response.access_token(), claims, - )?; + ) + .inspect_err(|e| tracing::error!(?e, "Access token hash invalid"))?; + + tracing::debug!("Access token hash validated"); login_session.authenticated = Some(AuthenticatedSession { id_token: id_token.clone(), @@ -70,6 +85,10 @@ pub async fn handle_oidc_redirect( login_session.refresh_token = Some(refresh_token); } + tracing::debug!( + "Inserting session and redirecting to {}", + &login_session.redirect_url + ); let redirect_url = login_session.redirect_url.clone(); session.insert(SESSION_KEY, login_session).await?; @@ -79,6 +98,7 @@ pub async fn handle_oidc_redirect( /// Verify the access token hash to ensure that the access token hasn't been substituted for /// another user's. /// Returns `Ok` when access token is valid +#[tracing::instrument(skip_all, err)] fn validate_access_token_hash( id_token: &IdToken, id_token_verifier: IdTokenVerifier, diff --git a/src/lib.rs b/src/lib.rs index 5251088..fe6aac0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ mod extractor; mod handler; mod middleware; -pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; +pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo}; pub use handler::handle_oidc_redirect; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; pub use openidconnect::{Audience, ClientId, ClientSecret}; @@ -108,6 +108,7 @@ pub struct OidcClient { http_client: reqwest::Client, end_session_endpoint: Option, auth_context_class: Option>, + untrusted_audiences: Vec, } /// an empty struct to be used as the default type for the additional claims generic diff --git a/src/middleware.rs b/src/middleware.rs index 4f66352..c7e0e7f 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -17,14 +17,14 @@ use tower_sessions::Session; use openidconnect::{ core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims, - IdTokenVerifier, Nonce, NonceVerifier, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken, + IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken, RequestTokenError::ServerResponse, - Scope, TokenResponse, + Scope, TokenResponse, UserInfoClaims, }; use crate::{ error::MiddlewareError, - extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, + extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo}, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, OidcSession, SESSION_KEY, }; @@ -237,6 +237,7 @@ where let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); let oidcclient = self.client.clone(); + Box::pin(async move { let (mut parts, body) = request.into_parts(); @@ -254,21 +255,44 @@ where let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { session .id_token - .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) + .claims( + &oidcclient + .client + .id_token_verifier() + .set_other_audience_verifier_fn(|audience| { + // Return false (reject) if audience is in list of untrusted audiences + !oidcclient.untrusted_audiences.contains(audience) + }), + &login_session.nonce, + ) .ok() .cloned() .map(|claims| (session, claims)) }); if let Some((session, claims)) = id_token_claims { + let user_claims = + get_user_claims(&oidcclient, session.access_token.clone()).await?; // stored id token is valid and can be used - insert_extensions(&mut parts, claims.clone(), &oidcclient, session); + insert_extensions( + &mut parts, + claims.clone(), + user_claims, + &oidcclient, + session, + ); } else if let Some(refresh_token) = login_session.refresh_token.as_ref() { // session is expired but can be refreshed using the refresh_token - if let Some((claims, authenticated_session, refresh_token)) = + if let Some((claims, user_claims, authenticated_session, refresh_token)) = try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? { - insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); + insert_extensions( + &mut parts, + claims, + user_claims.clone(), + &oidcclient, + &authenticated_session, + ); login_session.authenticated = Some(authenticated_session); if let Some(refresh_token) = refresh_token { @@ -311,10 +335,12 @@ where fn insert_extensions( parts: &mut Parts, claims: IdTokenClaims, + user_claims: UserInfoClaims, client: &OidcClient, authenticated_session: &AuthenticatedSession, ) { parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcUserInfo(user_claims)); parts.extensions.insert(OidcAccessToken( authenticated_session.access_token.secret().to_string(), )); @@ -356,6 +382,19 @@ fn validate_access_token_hash( } } +async fn get_user_claims( + client: &OidcClient, + access_token: AccessToken, +) -> Result, MiddlewareError> { + client + .client + .user_info(access_token, None) + .map_err(MiddlewareError::Configuration)? + .request_async(&client.http_client) + .await + .map_err(|e| e.into()) +} + async fn try_refresh_token( client: &OidcClient, refresh_token: &RefreshToken, @@ -363,6 +402,7 @@ async fn try_refresh_token( ) -> Result< Option<( IdTokenClaims, + UserInfoClaims, AuthenticatedSession, Option, )>, @@ -380,13 +420,13 @@ async fn try_refresh_token( let id_token = token_response .id_token() .ok_or(MiddlewareError::IdTokenMissing)?; - let id_token_verifier = client.client.id_token_verifier(); - let claims = id_token.claims(&id_token_verifier, |claims_nonce: Option<&Nonce>| { - match claims_nonce { - Some(_) => nonce.verify(claims_nonce), - None => Ok(()), - } - })?; + let id_token_verifier = client + .client + .id_token_verifier() + .set_other_audience_verifier_fn(|audience| + // Return false (reject) if audience is in list of untrusted audiences + !client.untrusted_audiences.contains(audience)); + let claims = id_token.claims(&id_token_verifier, nonce)?; validate_access_token_hash( id_token, @@ -400,8 +440,12 @@ async fn try_refresh_token( access_token: token_response.access_token().clone(), }; + let user_claims = + get_user_claims(client, authenticated_session.access_token.clone()).await?; + Ok(Some(( claims.clone(), + user_claims, authenticated_session, token_response.refresh_token().cloned(), )))