first implementation

This commit is contained in:
Paul Zinselmeyer 2023-11-03 19:42:54 +01:00
parent 1b3973064b
commit aa05cf6bde
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
7 changed files with 720 additions and 8 deletions

View file

@ -1,6 +1,6 @@
[package] [package]
name = "axum-oidc" name = "axum-oidc"
description = "A OpenID Connect Client and Bearer Token Libary for axum" description = "A OpenID Connect Client Libary for axum"
version = "0.0.0" version = "0.0.0"
edition = "2021" edition = "2021"
authors = [ "Paul Z <info@pfz4.de>" ] authors = [ "Paul Z <info@pfz4.de>" ]
@ -12,3 +12,15 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
thiserror = "1.0.50"
axum-core = "0.3"
axum = { version = "0.6", default-features = false, features = [ "query" ] }
tower-service = "0.3.2"
tower-layer = "0.3.2"
tower-sessions = { version = "0.4", default-features = false, features = [ "axum-core" ] }
http = "0.2"
async-trait = "0.1"
openidconnect = "3.4"
serde = "1.0"
futures-util = "0.3"
reqwest = { version = "0.11", default-features = false }

View file

@ -1,13 +1,72 @@
**This crate is still under construction** This Library allows using [OpenID Connect](https://openid.net/developers/how-connect-works/) with [axum](https://github.com/tokio-rs/axum).
It authenticates the user with the OpenID Conenct Issuer and provides Extractors.
This Library allows using [OpenID Connect](https://openid.net/developers/how-connect-works/) with [axum](https://github.com/tokio-rs/axum). It provides two modes, described below. # Usage
The `OidcAuthLayer` must be loaded on any handler that might use the extractors.
The user won't be automatically logged in using this layer.
If a valid session is found, the extractors will return the correct value and fail otherwise.
# Operating Modes The `OidcLoginLayer` should be loaded on any handler on which the user is supposed to be authenticated.
## Client Mode The User will be redirected to the OpenId Conect Issuer to authenticate.
In Client mode, the user visits the axum server with a web browser. The user gets redirected to and authenticated with the Issuer. The extractors will always return a value.
## Token Mode The `OidcClaims`-extractor can be used to get the OpenId Conenct Claims.
In Token mode, the another system is using the access token of the user to authenticate against the axum server. The `OidcAccessToken`-extractor can be used to get the OpenId Connect Access Token.
Your OIDC-Client must be allowed to redirect to **every** subpath of your application base url.
```rust
#[tokio::main]
async fn main() {
let session_store = MemoryStore::default();
let session_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::BAD_REQUEST
}))
.layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax));
let oidc_login_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
let oidc_auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
Uri::from_static("https://example.com"),
"<issuer>".to_string(),
"<client_id>".to_string(),
"<client_secret>".to_owned(),
vec![],
).await.unwrap(),
);
let app = Router::new()
.route("/", get(|| async { "Hello, authenticated World!" }))
.layer(oidc_login_service)
.layer(oidc_auth_service)
.layer(session_service);
axum::Server::bind(&"[::]:8080".parse().unwrap())
.serve(app.into_make_service())
.await
.unwrap();
}
```
# Example Projects
Here is a place for projects that are using this library.
- [zettoIT ARS - AudienceResponseSystem](https://git2.zettoit.eu/zettoit/ars) (by me)
# Contributing
I'm happy about any contribution in any form.
Feel free to submit feature requests and bug reports using a GitHub Issue.
PR's are also appreciated.
# License # License
This Library is licensed under [LGPLv3](https://www.gnu.org/licenses/lgpl-3.0.en.html). This Library is licensed under [LGPLv3](https://www.gnu.org/licenses/lgpl-3.0.en.html).

100
src/error.rs Normal file
View file

@ -0,0 +1,100 @@
use axum_core::{response::IntoResponse, BoxError};
use http::{
uri::{InvalidUri, InvalidUriParts},
StatusCode,
};
use openidconnect::{core::CoreErrorResponseType, StandardErrorResponse};
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ExtractorError {
#[error("unauthorized")]
Unauthorized,
}
#[derive(Debug, Error)]
pub enum MiddlewareError {
#[error("access token hash invalid")]
AccessTokenHashInvalid,
#[error("csrf token invalid")]
CsrfTokenInvalid,
#[error("id token missing")]
IdTokenMissing,
#[error("signing: {0:?}")]
Signing(#[from] openidconnect::SigningError),
#[error("claims verification: {0:?}")]
ClaimsVerification(#[from] openidconnect::ClaimsVerificationError),
#[error("url parsing: {0:?}")]
UrlParsing(#[from] openidconnect::url::ParseError),
#[error("uri parsing: {0:?}")]
UriParsing(#[from] InvalidUri),
#[error("uri parts parsing: {0:?}")]
UriPartsParsing(#[from] InvalidUriParts),
#[error("request token: {0:?}")]
RequestToken(
#[from]
openidconnect::RequestTokenError<
openidconnect::reqwest::Error<reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>,
),
#[error("session error: {0:?}")]
Session(#[from] tower_sessions::session::Error),
#[error("session not found")]
SessionNotFound,
#[error("next middleware")]
NextMiddleware(#[from] BoxError),
#[error("auth middleware not found")]
AuthMiddlewareNotFound,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("url parsing: {0:?}")]
UrlParsing(#[from] openidconnect::url::ParseError),
#[error("discovery: {0:?}")]
Discovery(#[from] openidconnect::DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
#[error("extractor: {0:?}")]
Extractor(#[from] ExtractorError),
#[error("extractor: {0:?}")]
Middleware(#[from] MiddlewareError),
}
impl IntoResponse for ExtractorError {
fn into_response(self) -> axum_core::response::Response {
(StatusCode::UNAUTHORIZED, "unauthorized").into_response()
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum_core::response::Response {
dbg!(&self);
match self {
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
}
}
impl IntoResponse for MiddlewareError {
fn into_response(self) -> axum_core::response::Response {
dbg!(&self);
match self {
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
}
}

50
src/extractor.rs Normal file
View file

@ -0,0 +1,50 @@
use crate::{error::ExtractorError, AdditionalClaims};
use async_trait::async_trait;
use axum_core::extract::FromRequestParts;
use http::request::Parts;
use openidconnect::{core::CoreGenderClaim, IdTokenClaims};
/// Extractor for the OpenID Connect Claims.
///
/// This Extractor will only return the Claims when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded.
#[derive(Clone)]
pub struct OidcClaims<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
#[async_trait]
impl<S, AC> FromRequestParts<S> for OidcClaims<AC>
where
S: Send + Sync,
AC: AdditionalClaims,
{
type Rejection = ExtractorError;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Self>()
.cloned()
.ok_or(ExtractorError::Unauthorized)
}
}
/// Extractor for the OpenID Connect Access Token.
///
/// This Extractor will only return the Access Token when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded.
#[derive(Clone)]
pub struct OidcAccessToken(pub String);
#[async_trait]
impl<S> FromRequestParts<S> for OidcAccessToken
where
S: Send + Sync,
{
type Rejection = ExtractorError;
async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Self>()
.cloned()
.ok_or(ExtractorError::Unauthorized)
}
}

View file

@ -0,0 +1,125 @@
#![doc = include_str!("../README.md")]
use crate::error::Error;
use http::Uri;
use openidconnect::{
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey,
CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken,
CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType,
},
reqwest::async_http_client,
ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce,
PkceCodeVerifier, StandardErrorResponse, StandardTokenResponse,
};
use serde::{Deserialize, Serialize};
pub mod error;
mod extractor;
mod middleware;
mod util;
pub use extractor::{OidcAccessToken, OidcClaims};
pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware};
const SESSION_KEY: &str = "axum-oidc";
pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {}
type OidcTokenResponse<AC> = StandardTokenResponse<
IdTokenFields<
AC,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
>,
CoreTokenType,
>;
pub type IdToken<AZ> = openidconnect::IdToken<
AZ,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
>;
type Client<AC> = openidconnect::Client<
AC,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
CoreJsonWebKeyUse,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
OidcTokenResponse<AC>,
CoreTokenType,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
>;
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// OpenID Connect Client
#[derive(Clone)]
pub struct OidcClient<AC: AdditionalClaims> {
scopes: Vec<String>,
client: Client<AC>,
application_base_url: Uri,
}
impl<AC: AdditionalClaims> OidcClient<AC> {
pub async fn discover_new(
application_base_url: Uri,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self, Error> {
let provider_metadata =
CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client)
.await?;
let client = Client::from_provider_metadata(
provider_metadata,
ClientId::new(client_id),
client_secret.map(|x| ClientSecret::new(x)),
);
Ok(Self {
scopes,
client,
application_base_url,
})
}
}
/// an empty struct to be used as the default type for the additional claims generic
#[derive(Deserialize, Serialize, Debug, Clone, Copy, Default)]
pub struct EmptyAdditionalClaims {}
impl AdditionalClaims for EmptyAdditionalClaims {}
impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
/// response data of the openid issuer after login
#[derive(Debug, Deserialize)]
struct OidcQuery {
code: String,
state: String,
#[allow(dead_code)]
session_state: String,
}
/// oidc session
#[derive(Serialize, Deserialize, Debug)]
struct OidcSession {
nonce: Nonce,
csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier,
id_token: Option<String>,
access_token: Option<String>,
}

334
src/middleware.rs Normal file
View file

@ -0,0 +1,334 @@
use std::{
marker::PhantomData,
str::FromStr,
task::{Context, Poll},
};
use axum::{
extract::Query,
response::{IntoResponse, Redirect},
};
use axum_core::{extract::FromRequestParts, response::Response};
use futures_util::future::BoxFuture;
use http::{Request, Uri};
use tower_layer::Layer;
use tower_service::Service;
use tower_sessions::Session;
use openidconnect::{
core::CoreAuthenticationFlow, reqwest::async_http_client, AccessTokenHash, AuthorizationCode,
CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope,
TokenResponse,
};
use crate::{
error::{Error, MiddlewareError},
extractor::{OidcAccessToken, OidcClaims},
util::strip_oidc_from_path,
AdditionalClaims, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, SESSION_KEY,
};
/// Layer for the [OidcLoginMiddleware].
#[derive(Clone, Default)]
pub struct OidcLoginLayer<AC>
where
AC: AdditionalClaims,
{
additional: PhantomData<AC>,
}
impl<AC: AdditionalClaims> OidcLoginLayer<AC> {
pub fn new() -> Self {
Self {
additional: PhantomData,
}
}
}
impl<I, AC> Layer<I> for OidcLoginLayer<AC>
where
AC: AdditionalClaims,
{
type Service = OidcLoginMiddleware<I, AC>;
fn layer(&self, inner: I) -> Self::Service {
OidcLoginMiddleware {
inner,
additional: PhantomData,
}
}
}
/// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect
/// Issuer to authenticate. This Middleware needs to be loaded afer [OidcAuthMiddleware].
#[derive(Clone)]
pub struct OidcLoginMiddleware<I, AC>
where
AC: AdditionalClaims,
{
inner: I,
additional: PhantomData<AC>,
}
impl<I, AC, B> Service<Request<B>> for OidcLoginMiddleware<I, AC>
where
I: Service<Request<B>, Response = Response> + Send + 'static + Clone,
I::Error: Send + Into<BoxError>,
I::Future: Send + 'static,
AC: AdditionalClaims,
B: Send + 'static,
{
type Response = I::Response;
type Error = MiddlewareError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner
.poll_ready(cx)
.map_err(|e| MiddlewareError::NextMiddleware(e.into()))
}
fn call(&mut self, request: Request<B>) -> Self::Future {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
if request.extensions().get::<OidcAccessToken>().is_some() {
Box::pin(async move {
let response: Response = inner
.call(request)
.await
.map_err(|e| MiddlewareError::NextMiddleware(e.into()))?;
return Ok(response);
})
} else {
Box::pin(async move {
let (mut parts, _) = request.into_parts();
let mut oidcclient: OidcClient<AC> = parts
.extensions
.get()
.cloned()
.ok_or(MiddlewareError::AuthMiddlewareNotFound)?;
let query = Query::<OidcQuery>::from_request_parts(&mut parts, &())
.await
.ok();
let session = parts
.extensions
.get::<Session>()
.ok_or(MiddlewareError::SessionNotFound)?;
let login_session: Option<OidcSession> =
session.get(SESSION_KEY).map_err(MiddlewareError::from)?;
let handler_uri =
strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?;
oidcclient.client = oidcclient
.client
.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?);
if let (Some(mut login_session), Some(query)) = (login_session, query) {
if login_session.csrf_token.secret() != &query.state {
return Err(MiddlewareError::CsrfTokenInvalid);
}
let token_response = oidcclient
.client
.exchange_code(AuthorizationCode::new(query.code.to_string()))
// Set the PKCE code verifier.
.set_pkce_verifier(PkceCodeVerifier::new(
login_session.pkce_verifier.secret().to_string(),
))
.request_async(async_http_client)
.await?;
// Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response
.id_token()
.ok_or(MiddlewareError::IdTokenMissing)?;
let claims = id_token
.claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?;
// Verify the access token hash to ensure that the access token hasn't been substituted for
// another user's.
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
&id_token.signing_alg()?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
return Err(MiddlewareError::AccessTokenHashInvalid);
}
}
login_session.id_token = Some(id_token.to_string());
login_session.access_token =
Some(token_response.access_token().secret().to_string());
session.insert(SESSION_KEY, login_session).unwrap();
Ok(Redirect::temporary(&handler_uri.to_string()).into_response())
} else {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = {
let mut auth = oidcclient.client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in oidcclient.scopes.iter() {
auth = auth.add_scope(Scope::new(scope.to_string()));
}
auth.set_pkce_challenge(pkce_challenge).url()
};
let oidc_session = OidcSession {
nonce,
csrf_token,
pkce_verifier,
id_token: None,
access_token: None,
};
session.insert(SESSION_KEY, oidc_session).unwrap();
Ok(Redirect::temporary(auth_url.as_str()).into_response())
}
})
}
}
}
/// Layer for the [OidcAuthMiddleware].
#[derive(Clone)]
pub struct OidcAuthLayer<AC>
where
AC: AdditionalClaims,
{
client: OidcClient<AC>,
}
impl<AC: AdditionalClaims> OidcAuthLayer<AC> {
pub fn new(client: OidcClient<AC>) -> Self {
Self { client }
}
pub async fn discover_client(
application_base_url: Uri,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
) -> Result<Self, Error> {
Ok(Self {
client: OidcClient::<AC>::discover_new(
application_base_url,
issuer,
client_id,
client_secret,
scopes,
)
.await?,
})
}
}
impl<I, AC> Layer<I> for OidcAuthLayer<AC>
where
AC: AdditionalClaims,
{
type Service = OidcAuthMiddleware<I, AC>;
fn layer(&self, inner: I) -> Self::Service {
OidcAuthMiddleware {
inner,
client: self.client.clone(),
}
}
}
/// This middleware checks if the cached session is valid and injects the Claims, the AccessToken
/// and the OidcClient in the request. This middleware needs to be loaded for every handler that is
/// using on of the Extractors. This middleware **doesn't force a user to be
/// authenticated**.
#[derive(Clone)]
pub struct OidcAuthMiddleware<I, AC>
where
AC: AdditionalClaims,
{
inner: I,
client: OidcClient<AC>,
}
impl<I, AC, B> Service<Request<B>> for OidcAuthMiddleware<I, AC>
where
I: Service<Request<B>> + Send + 'static + Clone,
I::Response: IntoResponse + Send,
I::Error: Send + Into<BoxError>,
I::Future: Send + 'static,
AC: AdditionalClaims,
B: Send + 'static,
{
type Response = Response;
type Error = MiddlewareError;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner
.poll_ready(cx)
.map_err(|e| MiddlewareError::NextMiddleware(e.into()))
}
fn call(&mut self, request: Request<B>) -> Self::Future {
let inner = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, inner);
let mut oidcclient = self.client.clone();
Box::pin(async move {
let (mut parts, body) = request.into_parts();
let session = parts
.extensions
.get::<Session>()
.ok_or(MiddlewareError::SessionNotFound)?;
let login_session: Option<OidcSession> =
session.get(SESSION_KEY).map_err(MiddlewareError::from)?;
let handler_uri =
strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?;
oidcclient.client = oidcclient
.client
.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?);
if let Some(OidcSession {
nonce,
csrf_token: _,
pkce_verifier: _,
id_token: Some(id_token),
access_token,
}) = &login_session
{
let id_token = IdToken::<AC>::from_str(&id_token).unwrap();
if let Ok(claims) = id_token.claims(&oidcclient.client.id_token_verifier(), nonce) {
parts.extensions.insert(OidcClaims(claims.clone()));
parts
.extensions
.insert(OidcAccessToken(access_token.clone().unwrap_or_default()));
}
}
parts.extensions.insert(oidcclient);
let request = Request::from_parts(parts, body);
let response: Response = inner
.call(request)
.await
.map_err(|e| MiddlewareError::NextMiddleware(e.into()))?
.into_response();
return Ok(response);
})
}
}

32
src/util.rs Normal file
View file

@ -0,0 +1,32 @@
use http::{uri::PathAndQuery, Uri};
use crate::error::MiddlewareError;
/// Helper function to remove the OpenID Connect authentication response query attributes from a
/// [`Uri`].
pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result<Uri, MiddlewareError> {
let mut base_url = base_url.into_parts();
base_url.path_and_query = uri
.path_and_query()
.map(|path_and_query| {
let query = path_and_query
.query()
.and_then(|uri| {
uri.split('&')
.filter(|x| {
!x.starts_with("code")
&& !x.starts_with("state")
&& !x.starts_with("session_state")
})
.map(|x| x.to_string())
.reduce(|acc, x| acc + "&" + &x)
})
.unwrap_or_default();
PathAndQuery::from_maybe_shared(format!("{}?{}", path_and_query.path(), query))
})
.transpose()?;
Ok(Uri::from_parts(base_url)?)
}