mirror of
https://github.com/pfzetto/axum-oidc.git
synced 2025-12-07 16:35:17 +01:00
add UserInfoClaims, add untrusted_audiences, add tracing
This commit is contained in:
parent
6280ad62cc
commit
094e9e5ff6
9 changed files with 210 additions and 43 deletions
16
Cargo.toml
16
Cargo.toml
|
|
@ -3,24 +3,30 @@ name = "axum-oidc"
|
||||||
description = "A wrapper for the openidconnect crate for axum"
|
description = "A wrapper for the openidconnect crate for axum"
|
||||||
version = "0.6.0"
|
version = "0.6.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = [ "Paul Z <info@pfz4.de>" ]
|
authors = ["Paul Z <info@pfz4.de>"]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
repository = "https://github.com/pfz4/axum-oidc"
|
repository = "https://github.com/pfz4/axum-oidc"
|
||||||
license = "MPL-2.0"
|
license = "MPL-2.0"
|
||||||
keywords = [ "axum", "oidc", "openidconnect", "authentication" ]
|
keywords = ["axum", "oidc", "openidconnect", "authentication"]
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
thiserror = "2.0"
|
thiserror = "2.0"
|
||||||
axum-core = "0.5"
|
axum-core = "0.5"
|
||||||
axum = { version = "0.8", default-features = false, features = [ "query", "original-uri" ] }
|
axum = { version = "0.8", default-features = false, features = [
|
||||||
|
"query",
|
||||||
|
"original-uri",
|
||||||
|
] }
|
||||||
tower-service = "0.3"
|
tower-service = "0.3"
|
||||||
tower-layer = "0.3"
|
tower-layer = "0.3"
|
||||||
tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] }
|
tower-sessions = { version = "0.14", default-features = false, features = [
|
||||||
http = "1.2"
|
"axum-core",
|
||||||
|
] }
|
||||||
|
http = "1.3.1"
|
||||||
openidconnect = "4.0"
|
openidconnect = "4.0"
|
||||||
serde = "1.0"
|
serde = "1.0"
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
reqwest = { version = "0.12", default-features = false }
|
reqwest = { version = "0.12", default-features = false }
|
||||||
urlencoding = "2.1"
|
urlencoding = "2.1"
|
||||||
|
tracing = "0.1.41"
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
[package]
|
[package]
|
||||||
|
edition = "2024"
|
||||||
name = "basic"
|
name = "basic"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { version = "1.43", features = ["net", "macros", "rt-multi-thread"] }
|
axum = { version = "0.8", features = ["macros"] }
|
||||||
axum = { version = "0.8", features = [ "macros" ]}
|
|
||||||
axum-oidc = { path = "./../.." }
|
axum-oidc = { path = "./../.." }
|
||||||
|
dotenvy = "0.15"
|
||||||
|
openidconnect = "4.0.1"
|
||||||
|
tokio = { version = "1.48.0", features = ["macros", "net", "rt-multi-thread"] }
|
||||||
tower = "0.5"
|
tower = "0.5"
|
||||||
tower-sessions = "0.14"
|
tower-sessions = "0.14"
|
||||||
dotenvy = "0.15"
|
tracing-subscriber = "0.3.20"
|
||||||
|
tracing = "0.1.41"
|
||||||
|
serde = "1.0.228"
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,9 @@ use axum::{
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use axum_oidc::{
|
use axum_oidc::{
|
||||||
error::MiddlewareError, handle_oidc_redirect, ClientId, ClientSecret, EmptyAdditionalClaims,
|
error::MiddlewareError, handle_oidc_redirect, Audience, ClientId, ClientSecret,
|
||||||
OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, OidcRpInitiatedLogout,
|
EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer,
|
||||||
|
OidcRpInitiatedLogout,
|
||||||
};
|
};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
|
|
@ -15,9 +16,15 @@ use tower_sessions::{
|
||||||
cookie::{time::Duration, SameSite},
|
cookie::{time::Duration, SameSite},
|
||||||
Expiry, MemoryStore, SessionManagerLayer,
|
Expiry, MemoryStore, SessionManagerLayer,
|
||||||
};
|
};
|
||||||
|
use tracing::Level;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
pub async fn main() {
|
async fn main() {
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_file(true)
|
||||||
|
.with_line_number(true)
|
||||||
|
.with_max_level(Level::INFO)
|
||||||
|
.init();
|
||||||
dotenvy::dotenv().ok();
|
dotenvy::dotenv().ok();
|
||||||
let issuer = std::env::var("ISSUER").expect("ISSUER env variable");
|
let issuer = std::env::var("ISSUER").expect("ISSUER env variable");
|
||||||
let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable");
|
let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable");
|
||||||
|
|
@ -39,7 +46,12 @@ pub async fn main() {
|
||||||
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
|
let mut oidc_client = OidcClient::<EmptyAdditionalClaims>::builder()
|
||||||
.with_default_http_client()
|
.with_default_http_client()
|
||||||
.with_redirect_url(Uri::from_static("http://localhost:8080/oidc"))
|
.with_redirect_url(Uri::from_static("http://localhost:8080/oidc"))
|
||||||
.with_client_id(ClientId::new(client_id));
|
.with_client_id(ClientId::new(client_id))
|
||||||
|
.add_scope("profile")
|
||||||
|
.add_scope("email")
|
||||||
|
// Optional: add untrusted audiences. If the `aud` claim contains any of these audiences, the token is rejected.
|
||||||
|
.add_untrusted_audience(Audience::new("123456789".to_string()));
|
||||||
|
|
||||||
if let Some(client_secret) = client_secret {
|
if let Some(client_secret) = client_secret {
|
||||||
oidc_client = oidc_client.with_client_secret(ClientSecret::new(client_secret));
|
oidc_client = oidc_client.with_client_secret(ClientSecret::new(client_secret));
|
||||||
}
|
}
|
||||||
|
|
@ -61,6 +73,9 @@ pub async fn main() {
|
||||||
.layer(oidc_auth_service)
|
.layer(oidc_auth_service)
|
||||||
.layer(session_layer);
|
.layer(session_layer);
|
||||||
|
|
||||||
|
tracing::info!("Running on http://localhost:8080");
|
||||||
|
tracing::info!("Visit http://localhost:8080/bar or http://localhost:8080/foo");
|
||||||
|
|
||||||
let listener = TcpListener::bind("[::]:8080").await.unwrap();
|
let listener = TcpListener::bind("[::]:8080").await.unwrap();
|
||||||
axum::serve(listener, app.into_make_service())
|
axum::serve(listener, app.into_make_service())
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use http::Uri;
|
use http::Uri;
|
||||||
use openidconnect::{ClientId, ClientSecret, IssuerUrl};
|
use openidconnect::{Audience, ClientId, ClientSecret, IssuerUrl};
|
||||||
|
|
||||||
use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata};
|
use crate::{error::Error, AdditionalClaims, Client, OidcClient, ProviderMetadata};
|
||||||
|
|
||||||
|
|
@ -23,6 +23,7 @@ pub struct Builder<AC: AdditionalClaims, Credentials, Client, HttpClient, Redire
|
||||||
end_session_endpoint: Option<Uri>,
|
end_session_endpoint: Option<Uri>,
|
||||||
scopes: Vec<Box<str>>,
|
scopes: Vec<Box<str>>,
|
||||||
auth_context_class: Option<Box<str>>,
|
auth_context_class: Option<Box<str>>,
|
||||||
|
untrusted_audiences: Vec<Audience>,
|
||||||
_ac: PhantomData<AC>,
|
_ac: PhantomData<AC>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -42,6 +43,7 @@ impl<AC: AdditionalClaims> Builder<AC, (), (), (), ()> {
|
||||||
end_session_endpoint: None,
|
end_session_endpoint: None,
|
||||||
scopes: vec![Box::from("openid")],
|
scopes: vec![Box::from("openid")],
|
||||||
auth_context_class: None,
|
auth_context_class: None,
|
||||||
|
untrusted_audiences: Vec::new(),
|
||||||
_ac: PhantomData,
|
_ac: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -60,6 +62,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP, RURL> Builder<AC, CREDS, CLIENT,
|
||||||
self.scopes.push(scope.into());
|
self.scopes.push(scope.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// replace scopes (including default)
|
/// replace scopes (including default)
|
||||||
pub fn with_scopes(mut self, scopes: impl Iterator<Item = impl Into<Box<str>>>) -> Self {
|
pub fn with_scopes(mut self, scopes: impl Iterator<Item = impl Into<Box<str>>>) -> Self {
|
||||||
self.scopes = scopes.map(|x| x.into()).collect::<Vec<_>>();
|
self.scopes = scopes.map(|x| x.into()).collect::<Vec<_>>();
|
||||||
|
|
@ -71,6 +74,18 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HTTP, RURL> Builder<AC, CREDS, CLIENT,
|
||||||
self.auth_context_class = Some(acr.into());
|
self.auth_context_class = Some(acr.into());
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// add a an untrusted audience to existing untrusted audiences
|
||||||
|
pub fn add_untrusted_audience(mut self, audience: Audience) -> Self {
|
||||||
|
self.untrusted_audiences.push(audience);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// replace untrusted audiences
|
||||||
|
pub fn with_untrusted_audiences(mut self, untrusted_audiences: Vec<Audience>) -> Self {
|
||||||
|
self.untrusted_audiences = untrusted_audiences;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RURL> {
|
impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RURL> {
|
||||||
|
|
@ -90,6 +105,7 @@ impl<AC: AdditionalClaims, CLIENT, HTTP, RURL> Builder<AC, (), CLIENT, HTTP, RUR
|
||||||
end_session_endpoint: self.end_session_endpoint,
|
end_session_endpoint: self.end_session_endpoint,
|
||||||
scopes: self.scopes,
|
scopes: self.scopes,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
_ac: PhantomData,
|
_ac: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -117,6 +133,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, RURL> Builder<AC, CREDS, CLIENT, (), R
|
||||||
end_session_endpoint: self.end_session_endpoint,
|
end_session_endpoint: self.end_session_endpoint,
|
||||||
scopes: self.scopes,
|
scopes: self.scopes,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
_ac: self._ac,
|
_ac: self._ac,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -130,6 +147,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, RURL> Builder<AC, CREDS, CLIENT, (), R
|
||||||
end_session_endpoint: self.end_session_endpoint,
|
end_session_endpoint: self.end_session_endpoint,
|
||||||
scopes: self.scopes,
|
scopes: self.scopes,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
_ac: self._ac,
|
_ac: self._ac,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -148,6 +166,7 @@ impl<AC: AdditionalClaims, CREDS, CLIENT, HCLIENT> Builder<AC, CREDS, CLIENT, HC
|
||||||
end_session_endpoint: self.end_session_endpoint,
|
end_session_endpoint: self.end_session_endpoint,
|
||||||
scopes: self.scopes,
|
scopes: self.scopes,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
_ac: self._ac,
|
_ac: self._ac,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -186,6 +205,7 @@ impl<AC: AdditionalClaims> Builder<AC, ClientCredentials, (), HttpClient, Redire
|
||||||
end_session_endpoint,
|
end_session_endpoint,
|
||||||
scopes: self.scopes,
|
scopes: self.scopes,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
_ac: self._ac,
|
_ac: self._ac,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -217,6 +237,7 @@ impl<AC: AdditionalClaims>
|
||||||
http_client: self.http_client.0,
|
http_client: self.http_client.0,
|
||||||
end_session_endpoint: self.end_session_endpoint,
|
end_session_endpoint: self.end_session_endpoint,
|
||||||
auth_context_class: self.auth_context_class,
|
auth_context_class: self.auth_context_class,
|
||||||
|
untrusted_audiences: self.untrusted_audiences,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
25
src/error.rs
25
src/error.rs
|
|
@ -41,6 +41,14 @@ pub enum MiddlewareError {
|
||||||
#[error("claims verification: {0:?}")]
|
#[error("claims verification: {0:?}")]
|
||||||
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
|
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
|
||||||
|
|
||||||
|
#[error("user info retrieval: {0:?}")]
|
||||||
|
UserInfoRetrieval(
|
||||||
|
#[from]
|
||||||
|
openidconnect::UserInfoError<
|
||||||
|
openidconnect::HttpClientError<openidconnect::reqwest::Error>,
|
||||||
|
>,
|
||||||
|
),
|
||||||
|
|
||||||
#[error("url parsing: {0:?}")]
|
#[error("url parsing: {0:?}")]
|
||||||
UrlParsing(#[from] openidconnect::url::ParseError),
|
UrlParsing(#[from] openidconnect::url::ParseError),
|
||||||
|
|
||||||
|
|
@ -77,7 +85,7 @@ pub enum MiddlewareError {
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum HandlerError {
|
pub enum HandlerError {
|
||||||
#[error("the redirect handler got accessed without a valid session")]
|
#[error("redirect handler accessed without valid session, session cookie missing?")]
|
||||||
RedirectedWithoutSession,
|
RedirectedWithoutSession,
|
||||||
|
|
||||||
#[error("csrf token invalid")]
|
#[error("csrf token invalid")]
|
||||||
|
|
@ -156,24 +164,21 @@ impl IntoResponse for ExtractorError {
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
fn into_response(self) -> axum_core::response::Response {
|
fn into_response(self) -> axum_core::response::Response {
|
||||||
match self {
|
tracing::error!(error = self.to_string());
|
||||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for MiddlewareError {
|
impl IntoResponse for MiddlewareError {
|
||||||
fn into_response(self) -> axum_core::response::Response {
|
fn into_response(self) -> axum_core::response::Response {
|
||||||
match self {
|
tracing::error!(error = self.to_string());
|
||||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for HandlerError {
|
impl IntoResponse for HandlerError {
|
||||||
fn into_response(self) -> axum_core::response::Response {
|
fn into_response(self) -> axum_core::response::Response {
|
||||||
match self {
|
tracing::error!(error = self.to_string());
|
||||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,12 @@ use axum_core::{
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
use http::{request::Parts, uri::PathAndQuery, Uri};
|
use http::{request::Parts, uri::PathAndQuery, Uri};
|
||||||
use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims};
|
use openidconnect::{core::CoreGenderClaim, ClientId, IdTokenClaims, UserInfoClaims};
|
||||||
|
|
||||||
/// Extractor for the OpenID Connect Claims.
|
/// Extractor for the OpenID Connect Claims.
|
||||||
///
|
///
|
||||||
/// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
|
/// This Extractor will only return the Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
|
||||||
#[derive(Clone)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
|
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
|
||||||
|
|
||||||
impl<S, AC> FromRequestParts<S> for OidcClaims<AC>
|
impl<S, AC> FromRequestParts<S> for OidcClaims<AC>
|
||||||
|
|
@ -213,3 +213,54 @@ impl IntoResponse for OidcRpInitiatedLogout {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extractor for the OpenID Connect User Info Claims.
|
||||||
|
///
|
||||||
|
/// This Extractor will only return the User Info Claims when the cached session is valid and [`crate::middleware::OidcAuthMiddleware`] is loaded.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct OidcUserInfo<AC: AdditionalClaims>(pub UserInfoClaims<AC, CoreGenderClaim>);
|
||||||
|
|
||||||
|
impl<S, AC> FromRequestParts<S> for OidcUserInfo<AC>
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
AC: AdditionalClaims,
|
||||||
|
{
|
||||||
|
type Rejection = ExtractorError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
parts
|
||||||
|
.extensions
|
||||||
|
.get::<Self>()
|
||||||
|
.cloned()
|
||||||
|
.ok_or(ExtractorError::Unauthorized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, AC> OptionalFromRequestParts<S> for OidcUserInfo<AC>
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
AC: AdditionalClaims,
|
||||||
|
{
|
||||||
|
type Rejection = Infallible;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Option<Self>, Self::Rejection> {
|
||||||
|
Ok(parts.extensions.get::<Self>().cloned())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<AC: AdditionalClaims> Deref for OidcUserInfo<AC> {
|
||||||
|
type Target = UserInfoClaims<AC, CoreGenderClaim>;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<AC> AsRef<UserInfoClaims<AC, CoreGenderClaim>> for OidcUserInfo<AC>
|
||||||
|
where
|
||||||
|
AC: AdditionalClaims,
|
||||||
|
{
|
||||||
|
fn as_ref(&self) -> &UserInfoClaims<AC, CoreGenderClaim> {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -21,11 +21,14 @@ pub struct OidcQuery {
|
||||||
session_state: Option<String>,
|
session_state: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(oidcclient), err)]
|
||||||
pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||||
session: Session,
|
session: Session,
|
||||||
Extension(oidcclient): Extension<OidcClient<AC>>,
|
Extension(oidcclient): Extension<OidcClient<AC>>,
|
||||||
Query(query): Query<OidcQuery>,
|
Query(query): Query<OidcQuery>,
|
||||||
) -> Result<impl axum::response::IntoResponse, HandlerError> {
|
) -> Result<impl axum::response::IntoResponse, HandlerError> {
|
||||||
|
tracing::debug!("start handling oidc redirect");
|
||||||
|
|
||||||
let mut login_session: OidcSession<AC> = session
|
let mut login_session: OidcSession<AC> = session
|
||||||
.get(SESSION_KEY)
|
.get(SESSION_KEY)
|
||||||
.await?
|
.await?
|
||||||
|
|
@ -33,10 +36,12 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||||
// the request has the request headers of the oidc redirect
|
// the request has the request headers of the oidc redirect
|
||||||
// parse the headers and exchange the code for a valid token
|
// parse the headers and exchange the code for a valid token
|
||||||
|
|
||||||
|
tracing::debug!("validating scrf token");
|
||||||
if login_session.csrf_token.secret() != &query.state {
|
if login_session.csrf_token.secret() != &query.state {
|
||||||
return Err(HandlerError::CsrfTokenInvalid);
|
return Err(HandlerError::CsrfTokenInvalid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tracing::debug!("obtain token response");
|
||||||
let token_response = oidcclient
|
let token_response = oidcclient
|
||||||
.client
|
.client
|
||||||
.exchange_code(AuthorizationCode::new(query.code.to_string()))?
|
.exchange_code(AuthorizationCode::new(query.code.to_string()))?
|
||||||
|
|
@ -47,19 +52,29 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||||
.request_async(&oidcclient.http_client)
|
.request_async(&oidcclient.http_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
tracing::debug!("extract claims and verify it");
|
||||||
// Extract the ID token claims after verifying its authenticity and nonce.
|
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||||
let id_token = token_response
|
let id_token = token_response
|
||||||
.id_token()
|
.id_token()
|
||||||
.ok_or(HandlerError::IdTokenMissing)?;
|
.ok_or(HandlerError::IdTokenMissing)?;
|
||||||
let id_token_verifier = oidcclient.client.id_token_verifier();
|
let id_token_verifier = oidcclient
|
||||||
|
.client
|
||||||
|
.id_token_verifier()
|
||||||
|
.set_other_audience_verifier_fn(|audience|
|
||||||
|
// Return false (reject) if audience is in list of untrusted audiences
|
||||||
|
!oidcclient.untrusted_audiences.contains(audience));
|
||||||
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
|
let claims = id_token.claims(&id_token_verifier, &login_session.nonce)?;
|
||||||
|
|
||||||
|
tracing::debug!("validate access token hash");
|
||||||
validate_access_token_hash(
|
validate_access_token_hash(
|
||||||
id_token,
|
id_token,
|
||||||
id_token_verifier,
|
id_token_verifier,
|
||||||
token_response.access_token(),
|
token_response.access_token(),
|
||||||
claims,
|
claims,
|
||||||
)?;
|
)
|
||||||
|
.inspect_err(|e| tracing::error!(?e, "Access token hash invalid"))?;
|
||||||
|
|
||||||
|
tracing::debug!("Access token hash validated");
|
||||||
|
|
||||||
login_session.authenticated = Some(AuthenticatedSession {
|
login_session.authenticated = Some(AuthenticatedSession {
|
||||||
id_token: id_token.clone(),
|
id_token: id_token.clone(),
|
||||||
|
|
@ -70,6 +85,10 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||||
login_session.refresh_token = Some(refresh_token);
|
login_session.refresh_token = Some(refresh_token);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Inserting session and redirecting to {}",
|
||||||
|
&login_session.redirect_url
|
||||||
|
);
|
||||||
let redirect_url = login_session.redirect_url.clone();
|
let redirect_url = login_session.redirect_url.clone();
|
||||||
session.insert(SESSION_KEY, login_session).await?;
|
session.insert(SESSION_KEY, login_session).await?;
|
||||||
|
|
||||||
|
|
@ -79,6 +98,7 @@ pub async fn handle_oidc_redirect<AC: AdditionalClaims>(
|
||||||
/// Verify the access token hash to ensure that the access token hasn't been substituted for
|
/// Verify the access token hash to ensure that the access token hasn't been substituted for
|
||||||
/// another user's.
|
/// another user's.
|
||||||
/// Returns `Ok` when access token is valid
|
/// Returns `Ok` when access token is valid
|
||||||
|
#[tracing::instrument(skip_all, err)]
|
||||||
fn validate_access_token_hash<AC: AdditionalClaims>(
|
fn validate_access_token_hash<AC: AdditionalClaims>(
|
||||||
id_token: &IdToken<AC>,
|
id_token: &IdToken<AC>,
|
||||||
id_token_verifier: IdTokenVerifier<CoreJsonWebKey>,
|
id_token_verifier: IdTokenVerifier<CoreJsonWebKey>,
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ mod extractor;
|
||||||
mod handler;
|
mod handler;
|
||||||
mod middleware;
|
mod middleware;
|
||||||
|
|
||||||
pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout};
|
pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo};
|
||||||
pub use handler::handle_oidc_redirect;
|
pub use handler::handle_oidc_redirect;
|
||||||
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
|
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
|
||||||
pub use openidconnect::{Audience, ClientId, ClientSecret};
|
pub use openidconnect::{Audience, ClientId, ClientSecret};
|
||||||
|
|
@ -108,6 +108,7 @@ pub struct OidcClient<AC: AdditionalClaims> {
|
||||||
http_client: reqwest::Client,
|
http_client: reqwest::Client,
|
||||||
end_session_endpoint: Option<Uri>,
|
end_session_endpoint: Option<Uri>,
|
||||||
auth_context_class: Option<Box<str>>,
|
auth_context_class: Option<Box<str>>,
|
||||||
|
untrusted_audiences: Vec<Audience>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// an empty struct to be used as the default type for the additional claims generic
|
/// an empty struct to be used as the default type for the additional claims generic
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,14 @@ use tower_sessions::Session;
|
||||||
use openidconnect::{
|
use openidconnect::{
|
||||||
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
|
core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
|
||||||
AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
|
AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
|
||||||
IdTokenVerifier, Nonce, NonceVerifier, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
|
IdTokenVerifier, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
|
||||||
RequestTokenError::ServerResponse,
|
RequestTokenError::ServerResponse,
|
||||||
Scope, TokenResponse,
|
Scope, TokenResponse, UserInfoClaims,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
error::MiddlewareError,
|
error::MiddlewareError,
|
||||||
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
|
extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout, OidcUserInfo},
|
||||||
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
|
AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
|
||||||
OidcSession, SESSION_KEY,
|
OidcSession, SESSION_KEY,
|
||||||
};
|
};
|
||||||
|
|
@ -237,6 +237,7 @@ where
|
||||||
let inner = self.inner.clone();
|
let inner = self.inner.clone();
|
||||||
let mut inner = std::mem::replace(&mut self.inner, inner);
|
let mut inner = std::mem::replace(&mut self.inner, inner);
|
||||||
let oidcclient = self.client.clone();
|
let oidcclient = self.client.clone();
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let (mut parts, body) = request.into_parts();
|
let (mut parts, body) = request.into_parts();
|
||||||
|
|
||||||
|
|
@ -254,21 +255,44 @@ where
|
||||||
let id_token_claims = login_session.authenticated.as_ref().and_then(|session| {
|
let id_token_claims = login_session.authenticated.as_ref().and_then(|session| {
|
||||||
session
|
session
|
||||||
.id_token
|
.id_token
|
||||||
.claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)
|
.claims(
|
||||||
|
&oidcclient
|
||||||
|
.client
|
||||||
|
.id_token_verifier()
|
||||||
|
.set_other_audience_verifier_fn(|audience| {
|
||||||
|
// Return false (reject) if audience is in list of untrusted audiences
|
||||||
|
!oidcclient.untrusted_audiences.contains(audience)
|
||||||
|
}),
|
||||||
|
&login_session.nonce,
|
||||||
|
)
|
||||||
.ok()
|
.ok()
|
||||||
.cloned()
|
.cloned()
|
||||||
.map(|claims| (session, claims))
|
.map(|claims| (session, claims))
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some((session, claims)) = id_token_claims {
|
if let Some((session, claims)) = id_token_claims {
|
||||||
|
let user_claims =
|
||||||
|
get_user_claims(&oidcclient, session.access_token.clone()).await?;
|
||||||
// stored id token is valid and can be used
|
// stored id token is valid and can be used
|
||||||
insert_extensions(&mut parts, claims.clone(), &oidcclient, session);
|
insert_extensions(
|
||||||
|
&mut parts,
|
||||||
|
claims.clone(),
|
||||||
|
user_claims,
|
||||||
|
&oidcclient,
|
||||||
|
session,
|
||||||
|
);
|
||||||
} else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
|
} else if let Some(refresh_token) = login_session.refresh_token.as_ref() {
|
||||||
// session is expired but can be refreshed using the refresh_token
|
// session is expired but can be refreshed using the refresh_token
|
||||||
if let Some((claims, authenticated_session, refresh_token)) =
|
if let Some((claims, user_claims, authenticated_session, refresh_token)) =
|
||||||
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
|
try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await?
|
||||||
{
|
{
|
||||||
insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session);
|
insert_extensions(
|
||||||
|
&mut parts,
|
||||||
|
claims,
|
||||||
|
user_claims.clone(),
|
||||||
|
&oidcclient,
|
||||||
|
&authenticated_session,
|
||||||
|
);
|
||||||
login_session.authenticated = Some(authenticated_session);
|
login_session.authenticated = Some(authenticated_session);
|
||||||
|
|
||||||
if let Some(refresh_token) = refresh_token {
|
if let Some(refresh_token) = refresh_token {
|
||||||
|
|
@ -311,10 +335,12 @@ where
|
||||||
fn insert_extensions<AC: AdditionalClaims>(
|
fn insert_extensions<AC: AdditionalClaims>(
|
||||||
parts: &mut Parts,
|
parts: &mut Parts,
|
||||||
claims: IdTokenClaims<AC, CoreGenderClaim>,
|
claims: IdTokenClaims<AC, CoreGenderClaim>,
|
||||||
|
user_claims: UserInfoClaims<AC, CoreGenderClaim>,
|
||||||
client: &OidcClient<AC>,
|
client: &OidcClient<AC>,
|
||||||
authenticated_session: &AuthenticatedSession<AC>,
|
authenticated_session: &AuthenticatedSession<AC>,
|
||||||
) {
|
) {
|
||||||
parts.extensions.insert(OidcClaims(claims));
|
parts.extensions.insert(OidcClaims(claims));
|
||||||
|
parts.extensions.insert(OidcUserInfo(user_claims));
|
||||||
parts.extensions.insert(OidcAccessToken(
|
parts.extensions.insert(OidcAccessToken(
|
||||||
authenticated_session.access_token.secret().to_string(),
|
authenticated_session.access_token.secret().to_string(),
|
||||||
));
|
));
|
||||||
|
|
@ -356,6 +382,19 @@ fn validate_access_token_hash<AC: AdditionalClaims>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_user_claims<AC: AdditionalClaims>(
|
||||||
|
client: &OidcClient<AC>,
|
||||||
|
access_token: AccessToken,
|
||||||
|
) -> Result<UserInfoClaims<AC, CoreGenderClaim>, MiddlewareError> {
|
||||||
|
client
|
||||||
|
.client
|
||||||
|
.user_info(access_token, None)
|
||||||
|
.map_err(MiddlewareError::Configuration)?
|
||||||
|
.request_async(&client.http_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.into())
|
||||||
|
}
|
||||||
|
|
||||||
async fn try_refresh_token<AC: AdditionalClaims>(
|
async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
client: &OidcClient<AC>,
|
client: &OidcClient<AC>,
|
||||||
refresh_token: &RefreshToken,
|
refresh_token: &RefreshToken,
|
||||||
|
|
@ -363,6 +402,7 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
) -> Result<
|
) -> Result<
|
||||||
Option<(
|
Option<(
|
||||||
IdTokenClaims<AC, CoreGenderClaim>,
|
IdTokenClaims<AC, CoreGenderClaim>,
|
||||||
|
UserInfoClaims<AC, CoreGenderClaim>,
|
||||||
AuthenticatedSession<AC>,
|
AuthenticatedSession<AC>,
|
||||||
Option<RefreshToken>,
|
Option<RefreshToken>,
|
||||||
)>,
|
)>,
|
||||||
|
|
@ -380,13 +420,13 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
let id_token = token_response
|
let id_token = token_response
|
||||||
.id_token()
|
.id_token()
|
||||||
.ok_or(MiddlewareError::IdTokenMissing)?;
|
.ok_or(MiddlewareError::IdTokenMissing)?;
|
||||||
let id_token_verifier = client.client.id_token_verifier();
|
let id_token_verifier = client
|
||||||
let claims = id_token.claims(&id_token_verifier, |claims_nonce: Option<&Nonce>| {
|
.client
|
||||||
match claims_nonce {
|
.id_token_verifier()
|
||||||
Some(_) => nonce.verify(claims_nonce),
|
.set_other_audience_verifier_fn(|audience|
|
||||||
None => Ok(()),
|
// Return false (reject) if audience is in list of untrusted audiences
|
||||||
}
|
!client.untrusted_audiences.contains(audience));
|
||||||
})?;
|
let claims = id_token.claims(&id_token_verifier, nonce)?;
|
||||||
|
|
||||||
validate_access_token_hash(
|
validate_access_token_hash(
|
||||||
id_token,
|
id_token,
|
||||||
|
|
@ -400,8 +440,12 @@ async fn try_refresh_token<AC: AdditionalClaims>(
|
||||||
access_token: token_response.access_token().clone(),
|
access_token: token_response.access_token().clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let user_claims =
|
||||||
|
get_user_claims(client, authenticated_session.access_token.clone()).await?;
|
||||||
|
|
||||||
Ok(Some((
|
Ok(Some((
|
||||||
claims.clone(),
|
claims.clone(),
|
||||||
|
user_claims,
|
||||||
authenticated_session,
|
authenticated_session,
|
||||||
token_response.refresh_token().cloned(),
|
token_response.refresh_token().cloned(),
|
||||||
)))
|
)))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue