mirror of
https://codeberg.org/pfzetto/axum-oidc
synced 2025-12-08 06:05:16 +01:00
add UserInfoClaims
add allow additional audiences add tracing update basic example apply clippy lints
This commit is contained in:
parent
65cb175603
commit
5952cbff95
9 changed files with 226 additions and 52 deletions
|
|
@ -24,3 +24,4 @@ serde = "1.0"
|
|||
futures-util = "0.3"
|
||||
reqwest = { version = "0.12", default-features = false }
|
||||
urlencoding = "2.1"
|
||||
tracing = "0.1.41"
|
||||
|
|
|
|||
|
|
@ -1,25 +1,32 @@
|
|||
[package]
|
||||
edition = "2021"
|
||||
name = "basic"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.43", features = ["net", "macros", "rt-multi-thread"] }
|
||||
axum = { version = "0.8", features = [ "macros" ]}
|
||||
axum = { version = "0.8", features = ["macros"] }
|
||||
axum-oidc = { path = "./../.." }
|
||||
dotenvy = "0.15"
|
||||
tokio = { version = "1.43", features = ["macros", "net", "rt-multi-thread"] }
|
||||
tower = "0.5"
|
||||
tower-sessions = "0.14"
|
||||
|
||||
dotenvy = "0.15"
|
||||
openidconnect = "4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
tracing-subscriber = "0.3.19"
|
||||
tracing = "0.1.41"
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger = "0.11"
|
||||
headless_chrome = "1.0"
|
||||
log = "0.4"
|
||||
reqwest = { version = "0.12", features = [
|
||||
"rustls-tls",
|
||||
], default-features = false }
|
||||
testcontainers = "0.23"
|
||||
tokio = { version = "1.43", features = ["rt-multi-thread"] }
|
||||
reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
|
||||
env_logger = "0.11"
|
||||
log = "0.4"
|
||||
headless_chrome = "1.0"
|
||||
#see https://github.com/rust-headless-chrome/rust-headless-chrome/issues/535
|
||||
auto_generate_cdp = "=0.4.4"
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@ use axum::{
|
|||
Router,
|
||||
};
|
||||
use axum_oidc::{
|
||||
error::MiddlewareError, handle_oidc_redirect, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims,
|
||||
OidcClient, OidcLoginLayer, OidcRpInitiatedLogout,
|
||||
error::MiddlewareError, handle_oidc_redirect, AdditionalClaims, Audience, Config,
|
||||
OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout, OidcUserClaims,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::net::TcpListener;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_sessions::{
|
||||
|
|
@ -16,6 +17,15 @@ use tower_sessions::{
|
|||
Expiry, MemoryStore, SessionManagerLayer,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
//struct MyAdditionalClaims(HashMap<String, serde_json::Value>);
|
||||
struct MyAdditionalClaims {
|
||||
admin: Option<bool>,
|
||||
}
|
||||
|
||||
impl AdditionalClaims for MyAdditionalClaims {}
|
||||
impl openidconnect::AdditionalClaims for MyAdditionalClaims {}
|
||||
|
||||
pub async fn run(issuer: String, client_id: String, client_secret: Option<String>) {
|
||||
let session_store = MemoryStore::default();
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
|
|
@ -28,30 +38,42 @@ pub async fn run(issuer: String, client_id: String, client_secret: Option<String
|
|||
dbg!(&e);
|
||||
e.into_response()
|
||||
}))
|
||||
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
|
||||
.layer(OidcLoginLayer::<MyAdditionalClaims>::new());
|
||||
|
||||
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
|
||||
let mut oidc_client = OidcClient::<MyAdditionalClaims>::builder()
|
||||
.with_default_http_client()
|
||||
.with_redirect_url(Uri::from_static("http://localhost:8080/oidc"))
|
||||
.with_client_id(client_id);
|
||||
.with_redirect_url(Uri::from_static("http://127.0.0.1:8080/oidc"))
|
||||
.with_client_id(client_id)
|
||||
.add_scope("profile")
|
||||
.add_scope("email")
|
||||
.add_scope("urn:zitadel:iam:org:project:id:zitadel:aud");
|
||||
if let Some(client_secret) = client_secret {
|
||||
oidc_client = oidc_client.with_client_secret(client_secret);
|
||||
}
|
||||
let oidc_client = oidc_client.discover(issuer).await.unwrap().build();
|
||||
|
||||
let config = Config {
|
||||
other_audiences: vec![
|
||||
Audience::new("318246545105453932".to_string()),
|
||||
Audience::new("318244871846527852".to_string()),
|
||||
Audience::new("317981086246456313".to_string()),
|
||||
],
|
||||
};
|
||||
|
||||
let oidc_auth_service = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
|
||||
dbg!(&e);
|
||||
e.into_response()
|
||||
}))
|
||||
.layer(OidcAuthLayer::new(oidc_client));
|
||||
.layer(OidcAuthLayer::new(oidc_client, config.clone()));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/foo", get(authenticated))
|
||||
.route("/logout", get(logout))
|
||||
.layer(oidc_login_service)
|
||||
.route("/bar", get(maybe_authenticated))
|
||||
.route("/oidc", any(handle_oidc_redirect::<EmptyAdditionalClaims>))
|
||||
.route("/oidc", any(handle_oidc_redirect::<MyAdditionalClaims>))
|
||||
.with_state(config)
|
||||
.layer(oidc_auth_service)
|
||||
.layer(session_layer);
|
||||
|
||||
|
|
@ -61,24 +83,25 @@ pub async fn run(issuer: String, client_id: String, client_secret: Option<String
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
async fn authenticated(claims: OidcClaims<EmptyAdditionalClaims>) -> impl IntoResponse {
|
||||
async fn authenticated(claims: OidcClaims<MyAdditionalClaims>) -> impl IntoResponse {
|
||||
format!("Hello {}", claims.subject().as_str())
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
async fn maybe_authenticated(
|
||||
claims: Result<OidcClaims<EmptyAdditionalClaims>, axum_oidc::error::ExtractorError>,
|
||||
claims: Result<OidcUserClaims<MyAdditionalClaims>, axum_oidc::error::ExtractorError>,
|
||||
) -> impl IntoResponse {
|
||||
if let Ok(claims) = claims {
|
||||
dbg!(&claims);
|
||||
format!(
|
||||
"Hello {}! You are already logged in from another Handler.",
|
||||
claims.subject().as_str()
|
||||
"Hello {:#?}! You are already logged in from another Handler.",
|
||||
claims.name().unwrap().get(None).unwrap().as_str()
|
||||
)
|
||||
} else {
|
||||
"Hello anon!".to_string()
|
||||
"Hello unauthenticated user!".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse {
|
||||
logout.with_post_logout_redirect(Uri::from_static("https://example.com"))
|
||||
logout.with_post_logout_redirect(Uri::from_static("http://127.0.0.1:8080/bar"))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
use basic::run;
|
||||
use tracing::Level;
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt()
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_max_level(Level::TRACE)
|
||||
.init();
|
||||
dotenvy::dotenv().ok();
|
||||
let issuer = std::env::var("ISSUER").expect("ISSUER env variable");
|
||||
let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable");
|
||||
|
|
|
|||
25
src/error.rs
25
src/error.rs
|
|
@ -41,6 +41,14 @@ pub enum MiddlewareError {
|
|||
#[error("claims verification: {0:?}")]
|
||||
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
|
||||
|
||||
#[error("user info retrieval: {0:?}")]
|
||||
UserInfoRetrieval(
|
||||
#[from]
|
||||
openidconnect::UserInfoError<
|
||||
openidconnect::HttpClientError<openidconnect::reqwest::Error>,
|
||||
>,
|
||||
),
|
||||
|
||||
#[error("url parsing: {0:?}")]
|
||||
UrlParsing(#[from] openidconnect::url::ParseError),
|
||||
|
||||
|
|
@ -74,7 +82,7 @@ pub enum MiddlewareError {
|
|||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HandlerError {
|
||||
#[error("the redirect handler got accessed without a valid session")]
|
||||
#[error("redirect handler accessed without valid session, session cookie missing?")]
|
||||
RedirectedWithoutSession,
|
||||
|
||||
#[error("csrf token invalid")]
|
||||
|
|
@ -153,24 +161,21 @@ impl IntoResponse for ExtractorError {
|
|||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> axum_core::response::Response {
|
||||
match self {
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||
}
|
||||
tracing::error!(error = self.to_string());
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for MiddlewareError {
|
||||
fn into_response(self) -> axum_core::response::Response {
|
||||
match self {
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||
}
|
||||
tracing::error!(error = self.to_string());
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for HandlerError {
|
||||
fn into_response(self) -> axum_core::response::Response {
|
||||
match self {
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||
}
|
||||
tracing::error!(error = self.to_string());
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ use axum_core::{
|
|||
response::IntoResponse,
|
||||
};
|
||||
use http::{request::Parts, uri::PathAndQuery, Uri};
|
||||
use openidconnect::{core::CoreGenderClaim, IdTokenClaims};
|
||||
use openidconnect::{core::CoreGenderClaim, IdTokenClaims, UserInfoClaims};
|
||||
|
||||
/// Extractor for the OpenID Connect Claims.
|
||||
///
|
||||
/// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
|
||||
|
||||
impl<S, AC> FromRequestParts<S> for OidcClaims<AC>
|
||||
|
|
@ -213,3 +213,55 @@ impl IntoResponse for OidcRpInitiatedLogout {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Extractor for the OpenID Connect User Info Claims.
|
||||
///
|
||||
/// This Extractor will only return the User Info Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OidcUserClaims<AC: AdditionalClaims>(pub UserInfoClaims<AC, CoreGenderClaim>);
|
||||
|
||||
impl<S, AC> FromRequestParts<S> for OidcUserClaims<AC>
|
||||
where
|
||||
S: Send + Sync,
|
||||
AC: AdditionalClaims,
|
||||
{
|
||||
type Rejection = ExtractorError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
|
||||
parts
|
||||
.extensions
|
||||
.get::<Self>()
|
||||
.cloned()
|
||||
.ok_or(ExtractorError::Unauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, AC> OptionalFromRequestParts<S> for OidcUserClaims<AC>
|
||||
where
|
||||
S: Send + Sync,
|
||||
AC: AdditionalClaims,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Option<Self>, Self::Rejection> {
|
||||
Ok(parts.extensions.get::<Self>().cloned())
|
||||
}
|
||||
}
|
||||
|
||||
impl<AC: AdditionalClaims> Deref for OidcUserClaims<AC> {
|
||||
type Target = UserInfoClaims<AC, CoreGenderClaim>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<AC> AsRef<UserInfoClaims<AC, CoreGenderClaim>> for OidcUserClaims<AC>
|
||||
where
|
||||
AC: AdditionalClaims,
|
||||
{
|
||||
fn as_ref(&self) -> &UserInfoClaims<AC, CoreGenderClaim> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
use axum::{extract::Query, response::Redirect, Extension};
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Redirect,
|
||||
Extension,
|
||||
};
|
||||
use openidconnect::{
|
||||
core::{CoreGenderClaim, CoreJsonWebKey},
|
||||
AccessToken, AccessTokenHash, AuthorizationCode, IdTokenClaims, IdTokenVerifier,
|
||||
|
|
@ -8,8 +12,8 @@ use serde::Deserialize;
|
|||
use tower_sessions::Session;
|
||||
|
||||
use crate::{
|
||||
error::HandlerError, AdditionalClaims, AuthenticatedSession, IdToken, OidcClient, OidcSession,
|
||||
SESSION_KEY,
|
||||
error::HandlerError, AdditionalClaims, AuthenticatedSession, Config, IdToken, OidcClient,
|
||||
OidcSession, SESSION_KEY,
|
||||
};
|
||||
|
||||
/// response data of the openid issuer after login
|
||||
|
|
@ -21,11 +25,16 @@ pub struct OidcQuery {
|
|||
session_state: Option<String>,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(oidcclient), err)]
|
||||
pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||
session: Session,
|
||||
Extension(oidcclient): Extension<OidcClient<AC>>,
|
||||
State(config): State<Config>,
|
||||
Query(query): Query<OidcQuery>,
|
||||
) -> Result<impl axum::response::IntoResponse, HandlerError> {
|
||||
|
||||
tracing::debug!("start handling oidc redirect");
|
||||
|
||||
let mut login_session: OidcSession<AC> = session
|
||||
.get(SESSION_KEY)
|
||||
.await?
|
||||
|
|
@ -33,10 +42,12 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
|||
// the request has the request headers of the oidc redirect
|
||||
// parse the headers and exchange the code for a valid token
|
||||
|
||||
tracing::debug!("validating scrf token");
|
||||
if login_session.csrf_token.secret() != &query.state {
|
||||
return Err(HandlerError::CsrfTokenInvalid);
|
||||
}
|
||||
|
||||
tracing::debug!("obtain token response");
|
||||
let token_response = oidcclient
|
||||
.client
|
||||
.exchange_code(AuthorizationCode::new(query.code.to_string()))?
|
||||
|
|
@ -47,19 +58,27 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
|||
.request_async(&oidcclient.http_client)
|
||||
.await?;
|
||||
|
||||
tracing::debug!("extract claims and verify it");
|
||||
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||
let id_token = token_response
|
||||
.id_token()
|
||||
.ok_or(HandlerError::IdTokenMissing)?;
|
||||
let id_token_verifier = oidcclient.client.id_token_verifier();
|
||||
let id_token_verifier = oidcclient
|
||||
.client
|
||||
.id_token_verifier()
|
||||
.set_other_audience_verifier_fn(|audience| config.other_audiences.contains(audience));
|
||||
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
|
||||
|
||||
tracing::debug!("validate access token hash");
|
||||
validate_access_token_hash(
|
||||
id_token,
|
||||
id_token_verifier,
|
||||
token_response.access_token(),
|
||||
claims,
|
||||
)?;
|
||||
)
|
||||
.inspect_err(|e| tracing::error!(?e, "Access token hash invalid"))?;
|
||||
|
||||
tracing::debug!("Access token hash validated");
|
||||
|
||||
login_session.authenticated = Some(AuthenticatedSession {
|
||||
id_token: id_token.clone(),
|
||||
|
|
@ -70,6 +89,10 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
|||
login_session.refresh_token = Some(refresh_token);
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
"Inserting session and redirecting to {}",
|
||||
&login_session.redirect_url
|
||||
);
|
||||
let redirect_url = login_session.redirect_url.clone();
|
||||
session.insert(SESSION_KEY, login_session).await?;
|
||||
|
||||
|
|
@ -79,6 +102,7 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
|||
/// Verify the access token hash to ensure that the access token hasn't been substituted for
|
||||
/// another user's.
|
||||
/// Returns `Ok` when access token is valid
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
fn validate_access_token_hash<AC: AdditionalClaims>(
|
||||
id_token: &IdToken<AC>,
|
||||
id_token_verifier: IdTokenVerifier<CoreJsonWebKey>,
|
||||
|
|
|
|||
|
|
@ -24,9 +24,10 @@ mod extractor;
|
|||
mod handler;
|
||||
mod middleware;
|
||||
|
||||
pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout};
|
||||
pub use extractor::{OidcAccessToken, OidcClaims, OidcUserClaims, OidcRpInitiatedLogout};
|
||||
pub use handler::handle_oidc_redirect;
|
||||
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
|
||||
pub use middleware::{Config, OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
|
||||
pub use openidconnect::Audience;
|
||||
|
||||
const SESSION_KEY: &str = "axum-oidc";
|
||||
|
||||
|
|
|
|||
|
|
@ -13,19 +13,24 @@ use tower_sessions::Session;
|
|||
|
||||
use openidconnect::{
|
||||
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
|
||||
AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
|
||||
AccessToken, AccessTokenHash, Audience, AuthenticationContextClass, CsrfToken, IdTokenClaims,
|
||||
IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
|
||||
RequestTokenError::ServerResponse,
|
||||
Scope, TokenResponse,
|
||||
Scope, TokenResponse, UserInfoClaims,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
error::MiddlewareError,
|
||||
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
|
||||
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserClaims},
|
||||
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
|
||||
OidcSession, SESSION_KEY,
|
||||
};
|
||||
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct Config {
|
||||
pub other_audiences: Vec<Audience>,
|
||||
}
|
||||
|
||||
/// Layer for the [`OidcLoginMiddleware`].
|
||||
#[derive(Clone, Default)]
|
||||
pub struct OidcLoginLayer<AC>
|
||||
|
|
@ -161,16 +166,17 @@ where
|
|||
AC: AdditionalClaims,
|
||||
{
|
||||
client: OidcClient<AC>,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl<AC: AdditionalClaims> OidcAuthLayer<AC> {
|
||||
pub fn new(client: OidcClient<AC>) -> Self {
|
||||
Self { client }
|
||||
pub fn new(client: OidcClient<AC>, config: Config) -> Self {
|
||||
Self { client, config }
|
||||
}
|
||||
}
|
||||
impl<AC: AdditionalClaims> From<OidcClient<AC>> for OidcAuthLayer<AC> {
|
||||
fn from(value: OidcClient<AC>) -> Self {
|
||||
Self::new(value)
|
||||
Self::new(value, Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -184,6 +190,7 @@ where
|
|||
OidcAuthMiddleware {
|
||||
inner,
|
||||
client: self.client.clone(),
|
||||
config: self.config.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -199,6 +206,7 @@ where
|
|||
{
|
||||
inner: I,
|
||||
client: OidcClient<AC>,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl<I, AC, B> Service<Request<B>> for OidcAuthMiddleware<I, AC>
|
||||
|
|
@ -223,7 +231,9 @@ where
|
|||
fn call(&mut self, request: Request<B>) -> Self::Future {
|
||||
let inner = self.inner.clone();
|
||||
let mut inner = std::mem::replace(&mut self.inner, inner);
|
||||
let other_audiences = self.config.other_audiences.clone();
|
||||
let oidcclient = self.client.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let (mut parts, body) = request.into_parts();
|
||||
|
||||
|
|
@ -241,21 +251,43 @@ where
|
|||
let id_token_claims = login_session.authenticated.as_ref().and_then(|session| {
|
||||
session
|
||||
.id_token
|
||||
.claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)
|
||||
.claims(
|
||||
&oidcclient
|
||||
.client
|
||||
.id_token_verifier()
|
||||
.set_other_audience_verifier_fn(|audience| {
|
||||
other_audiences.contains(audience)
|
||||
}),
|
||||
&login_session.nonce,
|
||||
)
|
||||
.ok()
|
||||
.cloned()
|
||||
.map(|claims| (session, claims))
|
||||
});
|
||||
|
||||
if let Some((session, claims)) = id_token_claims {
|
||||
let user_claims =
|
||||
get_user_claims(&oidcclient, session.access_token.clone()).await?;
|
||||
// stored id token is valid and can be used
|
||||
insert_extensions(&mut parts, claims.clone(), &oidcclient, session);
|
||||
insert_extensions(
|
||||
&mut parts,
|
||||
claims.clone(),
|
||||
user_claims,
|
||||
&oidcclient,
|
||||
session,
|
||||
);
|
||||
} else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
|
||||
// session is expired but can be refreshed using the refresh_token
|
||||
if let Some((claims, authenticated_session, refresh_token)) =
|
||||
if let Some((claims, user_claims, authenticated_session, refresh_token)) =
|
||||
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
|
||||
{
|
||||
insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session);
|
||||
insert_extensions(
|
||||
&mut parts,
|
||||
claims,
|
||||
user_claims.clone(),
|
||||
&oidcclient,
|
||||
&authenticated_session,
|
||||
);
|
||||
login_session.authenticated = Some(authenticated_session);
|
||||
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
|
|
@ -297,10 +329,12 @@ where
|
|||
fn insert_extensions<AC: AdditionalClaims>(
|
||||
parts: &mut Parts,
|
||||
claims: IdTokenClaims<AC, CoreGenderClaim>,
|
||||
user_claims: UserInfoClaims<AC, CoreGenderClaim>,
|
||||
client: &OidcClient<AC>,
|
||||
authenticated_session: &AuthenticatedSession<AC>,
|
||||
) {
|
||||
parts.extensions.insert(OidcClaims(claims));
|
||||
parts.extensions.insert(OidcUserClaims(user_claims));
|
||||
parts.extensions.insert(OidcAccessToken(
|
||||
authenticated_session.access_token.secret().to_string(),
|
||||
));
|
||||
|
|
@ -342,6 +376,19 @@ fn validate_access_token_hash<AC: AdditionalClaims>(
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_user_claims<AC: AdditionalClaims>(
|
||||
client: &OidcClient<AC>,
|
||||
access_token: AccessToken,
|
||||
) -> Result<UserInfoClaims<AC, CoreGenderClaim>, MiddlewareError> {
|
||||
client
|
||||
.client
|
||||
.user_info(access_token, None)
|
||||
.map_err(MiddlewareError::Configuration)?
|
||||
.request_async(&client.http_client)
|
||||
.await
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
async fn try_refresh_token<AC: AdditionalClaims>(
|
||||
client: &OidcClient<AC>,
|
||||
refresh_token: &RefreshToken,
|
||||
|
|
@ -349,6 +396,7 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
|||
) -> Result<
|
||||
Option<(
|
||||
IdTokenClaims<AC, CoreGenderClaim>,
|
||||
UserInfoClaims<AC, CoreGenderClaim>,
|
||||
AuthenticatedSession<AC>,
|
||||
Option<RefreshToken>,
|
||||
)>,
|
||||
|
|
@ -366,7 +414,10 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
|||
let id_token = token_response
|
||||
.id_token()
|
||||
.ok_or(MiddlewareError::IdTokenMissing)?;
|
||||
let id_token_verifier = client.client.id_token_verifier();
|
||||
let id_token_verifier = client
|
||||
.client
|
||||
.id_token_verifier()
|
||||
.require_audience_match(false);
|
||||
let claims = id_token.claims(&id_token_verifier, nonce)?;
|
||||
|
||||
validate_access_token_hash(
|
||||
|
|
@ -381,8 +432,12 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
|||
access_token: token_response.access_token().clone(),
|
||||
};
|
||||
|
||||
let user_claims =
|
||||
get_user_claims(client, authenticated_session.access_token.clone()).await?;
|
||||
|
||||
Ok(Some((
|
||||
claims.clone(),
|
||||
user_claims,
|
||||
authenticated_session,
|
||||
token_response.refresh_token().cloned(),
|
||||
)))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue