add UserInfoClaims, add untrusted_audiences, add tracing

This commit is contained in:
JuliDi 2025-11-24 11:18:54 +01:00
parent 6280ad62cc
commit 094e9e5ff6
No known key found for this signature in database
GPG key ID: E1E90AE563D09D63
9 changed files with 210 additions and 43 deletions

View file

@ -3,24 +3,30 @@ name = "axum-oidc"
description = "A wrapper for the openidconnect crate for axum" description = "A wrapper for the openidconnect crate for axum"
version = "0.6.0" version = "0.6.0"
edition = "2021" edition = "2021"
authors = [ "Paul Z <info@pfz4.de>" ] authors = ["Paul Z <info@pfz4.de>"]
readme = "README.md" readme = "README.md"
repository = "https://github.com/pfz4/axum-oidc" repository = "https://github.com/pfz4/axum-oidc"
license = "MPL-2.0" license = "MPL-2.0"
keywords = [ "axum", "oidc", "openidconnect", "authentication" ] keywords = ["axum", "oidc", "openidconnect", "authentication"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
thiserror = "2.0" thiserror = "2.0"
axum-core = "0.5" axum-core = "0.5"
axum = { version = "0.8", default-features = false, features = [ "query", "original-uri" ] } axum = { version = "0.8", default-features = false, features = [
"query",
"original-uri",
] }
tower-service = "0.3" tower-service = "0.3"
tower-layer = "0.3" tower-layer = "0.3"
tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } tower-sessions = { version = "0.14", default-features = false, features = [
http = "1.2" "axum-core",
] }
http = "1.3.1"
openidconnect = "4.0" openidconnect = "4.0"
serde = "1.0" serde = "1.0"
futures-util = "0.3" futures-util = "0.3"
reqwest = { version = "0.12", default-features = false } reqwest = { version = "0.12", default-features = false }
urlencoding = "2.1" urlencoding = "2.1"
tracing = "0.1.41"

View file

@ -1,12 +1,16 @@
[package] [package]
edition = "2024"
name = "basic" name = "basic"
version = "0.1.0" version = "0.1.0"
edition = "2021"
[dependencies] [dependencies]
tokio = { version = "1.43", features = ["net", "macros", "rt-multi-thread"] } axum = { version = "0.8", features = ["macros"] }
axum = { version = "0.8", features = [ "macros" ]}
axum-oidc = { path = "./../.." } axum-oidc = { path = "./../.." }
dotenvy = "0.15"
openidconnect = "4.0.1"
tokio = { version = "1.48.0", features = ["macros", "net", "rt-multi-thread"] }
tower = "0.5" tower = "0.5"
tower-sessions = "0.14" tower-sessions = "0.14"
dotenvy = "0.15" tracing-subscriber = "0.3.20"
tracing = "0.1.41"
serde = "1.0.228"

View file

@ -6,8 +6,9 @@ use axum::{
Router, Router,
}; };
use axum_oidc::{ use axum_oidc::{
error::MiddlewareError, handle_oidc_redirect, ClientId, ClientSecret, EmptyAdditionalClaims, error::MiddlewareError, handle_oidc_redirect, Audience, ClientId, ClientSecret,
OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer,
OidcRpInitiatedLogout,
}; };
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -15,9 +16,15 @@ use tower_sessions::{
cookie::{time::Duration, SameSite}, cookie::{time::Duration, SameSite},
Expiry, MemoryStore, SessionManagerLayer, Expiry, MemoryStore, SessionManagerLayer,
}; };
use tracing::Level;
#[tokio::main] #[tokio::main]
pub async fn main() { async fn main() {
tracing_subscriber::fmt()
.with_file(true)
.with_line_number(true)
.with_max_level(Level::INFO)
.init();
dotenvy::dotenv().ok(); dotenvy::dotenv().ok();
let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); let issuer = std::env::var("ISSUER").expect("ISSUER env variable");
let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable");
@ -39,7 +46,12 @@ pub async fn main() {
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder() let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
.with_default_http_client() .with_default_http_client()
.with_redirect_url(Uri::from_static("http://localhost:8080/oidc")) .with_redirect_url(Uri::from_static("http://localhost:8080/oidc"))
.with_client_id(ClientId::new(client_id)); .with_client_id(ClientId::new(client_id))
.add_scope("profile")
.add_scope("email")
// Optional: add untrusted audiences. If the `aud` claim contains any of these audiences, the token is rejected.
.add_untrusted_audience(Audience::new("123456789".to_string()));
if let Some(client_secret) = client_secret { if let Some(client_secret) = client_secret {
oidc_client = oidc_client.with_client_secret(ClientSecret::new(client_secret)); oidc_client = oidc_client.with_client_secret(ClientSecret::new(client_secret));
} }
@ -61,6 +73,9 @@ pub async fn main() {
.layer(oidc_auth_service) .layer(oidc_auth_service)
.layer(session_layer); .layer(session_layer);
tracing::info!("Running on http://localhost:8080");
tracing::info!("Visit http://localhost:8080/bar or http://localhost:8080/foo");
let listener = TcpListener::bind("[::]:8080").await.unwrap(); let listener = TcpListener::bind("[::]:8080").await.unwrap();
axum::serve(listener, app.into_make_service()) axum::serve(listener, app.into_make_service())
.await .await

View file

@ -1,7 +1,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
use http::Uri; use http::Uri;
use openidconnect::{ClientId, ClientSecret, IssuerUrl}; use openidconnect::{Audience, ClientId, ClientSecret, IssuerUrl};
use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata}; use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata};
@ -23,6 +23,7 @@ pub struct Builder<AC: AdditionalClaims, Credentials, Client, HttpClient, Redire
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
scopes: Vec<Box<str>>, scopes: Vec<Box<str>>,
auth_context_class: Option<Box<str>>, auth_context_class: Option<Box<str>>,
untrusted_audiences: Vec<Audience>,
_ac: PhantomData<AC>, _ac: PhantomData<AC>,
} }
@ -42,6 +43,7 @@ impl<AC: AdditionalClaims> Builder<AC, (), (), (), ()> {
end_session_endpoint: None, end_session_endpoint: None,
scopes: vec![Box::from("openid")], scopes: vec![Box::from("openid")],
auth_context_class: None, auth_context_class: None,
untrusted_audiences: Vec::new(),
_ac: PhantomData, _ac: PhantomData,
} }
} }
@ -60,6 +62,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP, RURL> Builder<AC, CREDS, CLIENT,
self.scopes.push(scope.into()); self.scopes.push(scope.into());
self self
} }
/// replace scopes (including default) /// replace scopes (including default)
pub fn with_scopes(mut self, scopes: impl Iterator<Item = impl Into<Box<str>>>) -> Self { pub fn with_scopes(mut self, scopes: impl Iterator<Item = impl Into<Box<str>>>) -> Self {
self.scopes = scopes.map(|x| x.into()).collect::<Vec<_>>(); self.scopes = scopes.map(|x| x.into()).collect::<Vec<_>>();
@ -71,6 +74,18 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP, RURL> Builder<AC, CREDS, CLIENT,
self.auth_context_class = Some(acr.into()); self.auth_context_class = Some(acr.into());
self self
} }
/// add a an untrusted audience to existing untrusted audiences
pub fn add_untrusted_audience(mut self, audience: Audience) -> Self {
self.untrusted_audiences.push(audience);
self
}
/// replace untrusted audiences
pub fn with_untrusted_audiences(mut self, untrusted_audiences: Vec<Audience>) -> Self {
self.untrusted_audiences = untrusted_audiences;
self
}
} }
impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RURL> { impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RURL> {
@ -90,6 +105,7 @@ impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RUR
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
_ac: PhantomData, _ac: PhantomData,
} }
} }
@ -117,6 +133,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, RURL> Builder<AC, CREDS, CLIENT, (), R
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
_ac: self._ac, _ac: self._ac,
} }
} }
@ -130,6 +147,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, RURL> Builder<AC, CREDS, CLIENT, (), R
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
_ac: self._ac, _ac: self._ac,
} }
} }
@ -148,6 +166,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HCLIENT> Builder<AC, CREDS, CLIENT, HC
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
_ac: self._ac, _ac: self._ac,
} }
} }
@ -186,6 +205,7 @@ impl<AC: AdditionalClaims> Builder<AC, ClientCredentials, (), HttpClient, Redire
end_session_endpoint, end_session_endpoint,
scopes: self.scopes, scopes: self.scopes,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
_ac: self._ac, _ac: self._ac,
}) })
} }
@ -217,6 +237,7 @@ impl<AC: AdditionalClaims>
http_client: self.http_client.0, http_client: self.http_client.0,
end_session_endpoint: self.end_session_endpoint, end_session_endpoint: self.end_session_endpoint,
auth_context_class: self.auth_context_class, auth_context_class: self.auth_context_class,
untrusted_audiences: self.untrusted_audiences,
} }
} }
} }

View file

@ -41,6 +41,14 @@ pub enum MiddlewareError {
#[error("claims verification: {0:?}")] #[error("claims verification: {0:?}")]
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
#[error("user info retrieval: {0:?}")]
UserInfoRetrieval(
#[from]
openidconnect::UserInfoError<
openidconnect::HttpClientError<openidconnect::reqwest::Error>,
>,
),
#[error("url parsing: {0:?}")] #[error("url parsing: {0:?}")]
UrlParsing(#[from] openidconnect::url::ParseError), UrlParsing(#[from] openidconnect::url::ParseError),
@ -77,7 +85,7 @@ pub enum MiddlewareError {
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum HandlerError { pub enum HandlerError {
#[error("the redirect handler got accessed without a valid session")] #[error("redirect handler accessed without valid session, session cookie missing?")]
RedirectedWithoutSession, RedirectedWithoutSession,
#[error("csrf token invalid")] #[error("csrf token invalid")]
@ -156,24 +164,21 @@ impl IntoResponse for ExtractorError {
impl IntoResponse for Error { impl IntoResponse for Error {
fn into_response(self) -> axum_core::response::Response { fn into_response(self) -> axum_core::response::Response {
match self { tracing::error!(error = self.to_string());
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
}
} }
} }
impl IntoResponse for MiddlewareError { impl IntoResponse for MiddlewareError {
fn into_response(self) -> axum_core::response::Response { fn into_response(self) -> axum_core::response::Response {
match self { tracing::error!(error = self.to_string());
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
}
} }
} }
impl IntoResponse for HandlerError { impl IntoResponse for HandlerError {
fn into_response(self) -> axum_core::response::Response { fn into_response(self) -> axum_core::response::Response {
match self { tracing::error!(error = self.to_string());
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
}
} }
} }

View file

@ -7,12 +7,12 @@ use axum_core::{
response::IntoResponse, response::IntoResponse,
}; };
use http::{request::Parts, uri::PathAndQuery, Uri}; use http::{request::Parts, uri::PathAndQuery, Uri};
use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims}; use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims, UserInfoClaims};
/// Extractor for the OpenID Connect Claims. /// Extractor for the OpenID Connect Claims.
/// ///
/// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded. /// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>); pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
impl<S, AC> FromRequestParts<S> for OidcClaims<AC> impl<S, AC> FromRequestParts<S> for OidcClaims<AC>
@ -213,3 +213,54 @@ impl IntoResponse for OidcRpInitiatedLogout {
} }
} }
} }
/// Extractor for the OpenID Connect User Info Claims.
///
/// This Extractor will only return the User Info Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
#[derive(Clone, Debug)]
pub struct OidcUserInfo<AC: AdditionalClaims>(pub UserInfoClaims<AC, CoreGenderClaim>);
impl<S, AC> FromRequestParts<S> for OidcUserInfo<AC>
where
S: Send + Sync,
AC: AdditionalClaims,
{
type Rejection = ExtractorError;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Self>()
.cloned()
.ok_or(ExtractorError::Unauthorized)
}
}
impl<S, AC> OptionalFromRequestParts<S> for OidcUserInfo<AC>
where
S: Send + Sync,
AC: AdditionalClaims,
{
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Option<Self>, Self::Rejection> {
Ok(parts.extensions.get::<Self>().cloned())
}
}
impl<AC: AdditionalClaims> Deref for OidcUserInfo<AC> {
type Target = UserInfoClaims<AC, CoreGenderClaim>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<AC> AsRef<UserInfoClaims<AC, CoreGenderClaim>> for OidcUserInfo<AC>
where
AC: AdditionalClaims,
{
fn as_ref(&self) -> &UserInfoClaims<AC, CoreGenderClaim> {
&self.0
}
}

View file

@ -21,11 +21,14 @@ pub struct OidcQuery {
session_state: Option<String>, session_state: Option<String>,
} }
#[tracing::instrument(skip(oidcclient), err)]
pub async fn handle_oidc_redirect<AC: AdditionalClaims>( pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
session: Session, session: Session,
Extension(oidcclient): Extension<OidcClient<AC>>, Extension(oidcclient): Extension<OidcClient<AC>>,
Query(query): Query<OidcQuery>, Query(query): Query<OidcQuery>,
) -> Result<impl axum::response::IntoResponse, HandlerError> { ) -> Result<impl axum::response::IntoResponse, HandlerError> {
tracing::debug!("start handling oidc redirect");
let mut login_session: OidcSession<AC> = session let mut login_session: OidcSession<AC> = session
.get(SESSION_KEY) .get(SESSION_KEY)
.await? .await?
@ -33,10 +36,12 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
// the request has the request headers of the oidc redirect // the request has the request headers of the oidc redirect
// parse the headers and exchange the code for a valid token // parse the headers and exchange the code for a valid token
tracing::debug!("validating scrf token");
if login_session.csrf_token.secret() != &query.state { if login_session.csrf_token.secret() != &query.state {
return Err(HandlerError::CsrfTokenInvalid); return Err(HandlerError::CsrfTokenInvalid);
} }
tracing::debug!("obtain token response");
let token_response = oidcclient let token_response = oidcclient
.client .client
.exchange_code(AuthorizationCode::new(query.code.to_string()))? .exchange_code(AuthorizationCode::new(query.code.to_string()))?
@ -47,19 +52,29 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
.request_async(&oidcclient.http_client) .request_async(&oidcclient.http_client)
.await?; .await?;
tracing::debug!("extract claims and verify it");
// Extract the ID token claims after verifying its authenticity and nonce. // Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response let id_token = token_response
.id_token() .id_token()
.ok_or(HandlerError::IdTokenMissing)?; .ok_or(HandlerError::IdTokenMissing)?;
let id_token_verifier = oidcclient.client.id_token_verifier(); let id_token_verifier = oidcclient
.client
.id_token_verifier()
.set_other_audience_verifier_fn(|audience|
// Return false (reject) if audience is in list of untrusted audiences
!oidcclient.untrusted_audiences.contains(audience));
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?; let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
tracing::debug!("validate access token hash");
validate_access_token_hash( validate_access_token_hash(
id_token, id_token,
id_token_verifier, id_token_verifier,
token_response.access_token(), token_response.access_token(),
claims, claims,
)?; )
.inspect_err(|e| tracing::error!(?e, "Access token hash invalid"))?;
tracing::debug!("Access token hash validated");
login_session.authenticated = Some(AuthenticatedSession { login_session.authenticated = Some(AuthenticatedSession {
id_token: id_token.clone(), id_token: id_token.clone(),
@ -70,6 +85,10 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
login_session.refresh_token = Some(refresh_token); login_session.refresh_token = Some(refresh_token);
} }
tracing::debug!(
"Inserting session and redirecting to {}",
&login_session.redirect_url
);
let redirect_url = login_session.redirect_url.clone(); let redirect_url = login_session.redirect_url.clone();
session.insert(SESSION_KEY, login_session).await?; session.insert(SESSION_KEY, login_session).await?;
@ -79,6 +98,7 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
/// Verify the access token hash to ensure that the access token hasn't been substituted for /// Verify the access token hash to ensure that the access token hasn't been substituted for
/// another user's. /// another user's.
/// Returns `Ok` when access token is valid /// Returns `Ok` when access token is valid
#[tracing::instrument(skip_all, err)]
fn validate_access_token_hash<AC: AdditionalClaims>( fn validate_access_token_hash<AC: AdditionalClaims>(
id_token: &IdToken<AC>, id_token: &IdToken<AC>,
id_token_verifier: IdTokenVerifier<CoreJsonWebKey>, id_token_verifier: IdTokenVerifier<CoreJsonWebKey>,

View file

@ -24,7 +24,7 @@ mod extractor;
mod handler; mod handler;
mod middleware; mod middleware;
pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo};
pub use handler::handle_oidc_redirect; pub use handler::handle_oidc_redirect;
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
pub use openidconnect::{Audience, ClientId, ClientSecret}; pub use openidconnect::{Audience, ClientId, ClientSecret};
@ -108,6 +108,7 @@ pub struct OidcClient<AC: AdditionalClaims> {
http_client: reqwest::Client, http_client: reqwest::Client,
end_session_endpoint: Option<Uri>, end_session_endpoint: Option<Uri>,
auth_context_class: Option<Box<str>>, auth_context_class: Option<Box<str>>,
untrusted_audiences: Vec<Audience>,
} }
/// an empty struct to be used as the default type for the additional claims generic /// an empty struct to be used as the default type for the additional claims generic

View file

@ -17,14 +17,14 @@ use tower_sessions::Session;
use openidconnect::{ use openidconnect::{
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey}, core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims, AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
IdTokenVerifier, Nonce, NonceVerifier, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken, IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
RequestTokenError::ServerResponse, RequestTokenError::ServerResponse,
Scope, TokenResponse, Scope, TokenResponse, UserInfoClaims,
}; };
use crate::{ use crate::{
error::MiddlewareError, error::MiddlewareError,
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo},
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient, AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
OidcSession, SESSION_KEY, OidcSession, SESSION_KEY,
}; };
@ -237,6 +237,7 @@ where
let inner = self.inner.clone(); let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner); let mut inner = std::mem::replace(&mut self.inner, inner);
let oidcclient = self.client.clone(); let oidcclient = self.client.clone();
Box::pin(async move { Box::pin(async move {
let (mut parts, body) = request.into_parts(); let (mut parts, body) = request.into_parts();
@ -254,21 +255,44 @@ where
let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { let id_token_claims = login_session.authenticated.as_ref().and_then(|session| {
session session
.id_token .id_token
.claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) .claims(
&oidcclient
.client
.id_token_verifier()
.set_other_audience_verifier_fn(|audience| {
// Return false (reject) if audience is in list of untrusted audiences
!oidcclient.untrusted_audiences.contains(audience)
}),
&login_session.nonce,
)
.ok() .ok()
.cloned() .cloned()
.map(|claims| (session, claims)) .map(|claims| (session, claims))
}); });
if let Some((session, claims)) = id_token_claims { if let Some((session, claims)) = id_token_claims {
let user_claims =
get_user_claims(&oidcclient, session.access_token.clone()).await?;
// stored id token is valid and can be used // stored id token is valid and can be used
insert_extensions(&mut parts, claims.clone(), &oidcclient, session); insert_extensions(
&mut parts,
claims.clone(),
user_claims,
&oidcclient,
session,
);
} else if let Some(refresh_token) = login_session.refresh_token.as_ref() { } else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
// session is expired but can be refreshed using the refresh_token // session is expired but can be refreshed using the refresh_token
if let Some((claims, authenticated_session, refresh_token)) = if let Some((claims, user_claims, authenticated_session, refresh_token)) =
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
{ {
insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); insert_extensions(
&mut parts,
claims,
user_claims.clone(),
&oidcclient,
&authenticated_session,
);
login_session.authenticated = Some(authenticated_session); login_session.authenticated = Some(authenticated_session);
if let Some(refresh_token) = refresh_token { if let Some(refresh_token) = refresh_token {
@ -311,10 +335,12 @@ where
fn insert_extensions<AC: AdditionalClaims>( fn insert_extensions<AC: AdditionalClaims>(
parts: &mut Parts, parts: &mut Parts,
claims: IdTokenClaims<AC, CoreGenderClaim>, claims: IdTokenClaims<AC, CoreGenderClaim>,
user_claims: UserInfoClaims<AC, CoreGenderClaim>,
client: &OidcClient<AC>, client: &OidcClient<AC>,
authenticated_session: &AuthenticatedSession<AC>, authenticated_session: &AuthenticatedSession<AC>,
) { ) {
parts.extensions.insert(OidcClaims(claims)); parts.extensions.insert(OidcClaims(claims));
parts.extensions.insert(OidcUserInfo(user_claims));
parts.extensions.insert(OidcAccessToken( parts.extensions.insert(OidcAccessToken(
authenticated_session.access_token.secret().to_string(), authenticated_session.access_token.secret().to_string(),
)); ));
@ -356,6 +382,19 @@ fn validate_access_token_hash<AC: AdditionalClaims>(
} }
} }
async fn get_user_claims<AC: AdditionalClaims>(
client: &OidcClient<AC>,
access_token: AccessToken,
) -> Result<UserInfoClaims<AC, CoreGenderClaim>, MiddlewareError> {
client
.client
.user_info(access_token, None)
.map_err(MiddlewareError::Configuration)?
.request_async(&client.http_client)
.await
.map_err(|e| e.into())
}
async fn try_refresh_token<AC: AdditionalClaims>( async fn try_refresh_token<AC: AdditionalClaims>(
client: &OidcClient<AC>, client: &OidcClient<AC>,
refresh_token: &RefreshToken, refresh_token: &RefreshToken,
@ -363,6 +402,7 @@ async fn try_refresh_token<AC: AdditionalClaims>(
) -> Result< ) -> Result<
Option<( Option<(
IdTokenClaims<AC, CoreGenderClaim>, IdTokenClaims<AC, CoreGenderClaim>,
UserInfoClaims<AC, CoreGenderClaim>,
AuthenticatedSession<AC>, AuthenticatedSession<AC>,
Option<RefreshToken>, Option<RefreshToken>,
)>, )>,
@ -380,13 +420,13 @@ async fn try_refresh_token<AC: AdditionalClaims>(
let id_token = token_response let id_token = token_response
.id_token() .id_token()
.ok_or(MiddlewareError::IdTokenMissing)?; .ok_or(MiddlewareError::IdTokenMissing)?;
let id_token_verifier = client.client.id_token_verifier(); let id_token_verifier = client
let claims = id_token.claims(&id_token_verifier, |claims_nonce: Option<&Nonce>| { .client
match claims_nonce { .id_token_verifier()
Some(_) => nonce.verify(claims_nonce), .set_other_audience_verifier_fn(|audience|
None => Ok(()), // Return false (reject) if audience is in list of untrusted audiences
} !client.untrusted_audiences.contains(audience));
})?; let claims = id_token.claims(&id_token_verifier, nonce)?;
validate_access_token_hash( validate_access_token_hash(
id_token, id_token,
@ -400,8 +440,12 @@ async fn try_refresh_token<AC: AdditionalClaims>(
access_token: token_response.access_token().clone(), access_token: token_response.access_token().clone(),
}; };
let user_claims =
get_user_claims(client, authenticated_session.access_token.clone()).await?;
Ok(Some(( Ok(Some((
claims.clone(), claims.clone(),
user_claims,
authenticated_session, authenticated_session,
token_response.refresh_token().cloned(), token_response.refresh_token().cloned(),
))) )))