From aa05cf6bdec85a12c14071882692c71902bb5de1 Mon Sep 17 00:00:00 2001 From: Paul Z Date: Fri, 3 Nov 2023 19:42:54 +0100 Subject: [PATCH] first implementation --- Cargo.toml | 14 +- README.md | 73 +++++++++- src/error.rs | 100 ++++++++++++++ src/extractor.rs | 50 +++++++ src/lib.rs | 125 +++++++++++++++++ src/middleware.rs | 334 ++++++++++++++++++++++++++++++++++++++++++++++ src/util.rs | 32 +++++ 7 files changed, 720 insertions(+), 8 deletions(-) create mode 100644 src/error.rs create mode 100644 src/extractor.rs create mode 100644 src/middleware.rs create mode 100644 src/util.rs diff --git a/Cargo.toml b/Cargo.toml index 837c3dc..d837096 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "axum-oidc" -description = "A OpenID Connect Client and Bearer Token Libary for axum" +description = "A OpenID Connect Client Libary for axum" version = "0.0.0" edition = "2021" authors = [ "Paul Z " ] @@ -12,3 +12,15 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +thiserror = "1.0.50" +axum-core = "0.3" +axum = { version = "0.6", default-features = false, features = [ "query" ] } +tower-service = "0.3.2" +tower-layer = "0.3.2" +tower-sessions = { version = "0.4", default-features = false, features = [ "axum-core" ] } +http = "0.2" +async-trait = "0.1" +openidconnect = "3.4" +serde = "1.0" +futures-util = "0.3" +reqwest = { version = "0.11", default-features = false } diff --git a/README.md b/README.md index ca293c6..1c11fcc 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,72 @@ -**This crate is still under construction** +This Library allows using [OpenID Connect](https://openid.net/developers/how-connect-works/) with [axum](https://github.com/tokio-rs/axum). +It authenticates the user with the OpenID Conenct Issuer and provides Extractors. -This Library allows using [OpenID Connect](https://openid.net/developers/how-connect-works/) with [axum](https://github.com/tokio-rs/axum). It provides two modes, described below. +# Usage +The `OidcAuthLayer` must be loaded on any handler that might use the extractors. +The user won't be automatically logged in using this layer. +If a valid session is found, the extractors will return the correct value and fail otherwise. -# Operating Modes -## Client Mode -In Client mode, the user visits the axum server with a web browser. The user gets redirected to and authenticated with the Issuer. +The `OidcLoginLayer` should be loaded on any handler on which the user is supposed to be authenticated. +The User will be redirected to the OpenId Conect Issuer to authenticate. +The extractors will always return a value. -## Token Mode -In Token mode, the another system is using the access token of the user to authenticate against the axum server. +The `OidcClaims`-extractor can be used to get the OpenId Conenct Claims. +The `OidcAccessToken`-extractor can be used to get the OpenId Connect Access Token. + +Your OIDC-Client must be allowed to redirect to **every** subpath of your application base url. + +```rust +#[tokio::main] +async fn main() { + + let session_store = MemoryStore::default(); + let session_service = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| async { + StatusCode::BAD_REQUEST + })) + .layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax)); + + let oidc_login_service = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|e: MiddlewareError| async { + e.into_response() + })) + .layer(OidcLoginLayer::::new()); + + let oidc_auth_service = ServiceBuilder::new() + .layer(HandleErrorLayer::new(|e: MiddlewareError| async { + e.into_response() + })) + .layer( + OidcAuthLayer::::discover_client( + Uri::from_static("https://example.com"), + "".to_string(), + "".to_string(), + "".to_owned(), + vec![], + ).await.unwrap(), + ); + + let app = Router::new() + .route("/", get(|| async { "Hello, authenticated World!" })) + .layer(oidc_login_service) + .layer(oidc_auth_service) + .layer(session_service); + + axum::Server::bind(&"[::]:8080".parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); +} +``` + +# Example Projects +Here is a place for projects that are using this library. +- [zettoIT ARS - AudienceResponseSystem](https://git2.zettoit.eu/zettoit/ars) (by me) + +# Contributing +I'm happy about any contribution in any form. +Feel free to submit feature requests and bug reports using a GitHub Issue. +PR's are also appreciated. # License This Library is licensed under [LGPLv3](https://www.gnu.org/licenses/lgpl-3.0.en.html). diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..c91ea66 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,100 @@ +use axum_core::{response::IntoResponse, BoxError}; +use http::{ + uri::{InvalidUri, InvalidUriParts}, + StatusCode, +}; +use openidconnect::{core::CoreErrorResponseType, StandardErrorResponse}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ExtractorError { + #[error("unauthorized")] + Unauthorized, +} + +#[derive(Debug, Error)] +pub enum MiddlewareError { + #[error("access token hash invalid")] + AccessTokenHashInvalid, + + #[error("csrf token invalid")] + CsrfTokenInvalid, + + #[error("id token missing")] + IdTokenMissing, + + #[error("signing: {0:?}")] + Signing(#[from] openidconnect::SigningError), + + #[error("claims verification: {0:?}")] + ClaimsVerification(#[from] openidconnect::ClaimsVerificationError), + + #[error("url parsing: {0:?}")] + UrlParsing(#[from] openidconnect::url::ParseError), + + #[error("uri parsing: {0:?}")] + UriParsing(#[from] InvalidUri), + + #[error("uri parts parsing: {0:?}")] + UriPartsParsing(#[from] InvalidUriParts), + + #[error("request token: {0:?}")] + RequestToken( + #[from] + openidconnect::RequestTokenError< + openidconnect::reqwest::Error, + StandardErrorResponse, + >, + ), + + #[error("session error: {0:?}")] + Session(#[from] tower_sessions::session::Error), + + #[error("session not found")] + SessionNotFound, + + #[error("next middleware")] + NextMiddleware(#[from] BoxError), + + #[error("auth middleware not found")] + AuthMiddlewareNotFound, +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("url parsing: {0:?}")] + UrlParsing(#[from] openidconnect::url::ParseError), + + #[error("discovery: {0:?}")] + Discovery(#[from] openidconnect::DiscoveryError>), + + #[error("extractor: {0:?}")] + Extractor(#[from] ExtractorError), + + #[error("extractor: {0:?}")] + Middleware(#[from] MiddlewareError), +} + +impl IntoResponse for ExtractorError { + fn into_response(self) -> axum_core::response::Response { + (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + } +} + +impl IntoResponse for Error { + fn into_response(self) -> axum_core::response::Response { + dbg!(&self); + match self { + _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), + } + } +} + +impl IntoResponse for MiddlewareError { + fn into_response(self) -> axum_core::response::Response { + dbg!(&self); + match self { + _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), + } + } +} diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 0000000..5e33b2a --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,50 @@ +use crate::{error::ExtractorError, AdditionalClaims}; +use async_trait::async_trait; +use axum_core::extract::FromRequestParts; +use http::request::Parts; +use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; + +/// Extractor for the OpenID Connect Claims. +/// +/// This Extractor will only return the Claims when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. +#[derive(Clone)] +pub struct OidcClaims(pub IdTokenClaims); + +#[async_trait] +impl FromRequestParts for OidcClaims +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} + +/// Extractor for the OpenID Connect Access Token. +/// +/// This Extractor will only return the Access Token when the cached session is valid and [crate::middleware::OidcAuthMiddleware] is loaded. +#[derive(Clone)] +pub struct OidcAccessToken(pub String); + +#[async_trait] +impl FromRequestParts for OidcAccessToken +where + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} diff --git a/src/lib.rs b/src/lib.rs index e69de29..609060f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1,125 @@ +#![doc = include_str!("../README.md")] + +use crate::error::Error; +use http::Uri; +use openidconnect::{ + core::{ + CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, + CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken, + CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType, + }, + reqwest::async_http_client, + ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, + PkceCodeVerifier, StandardErrorResponse, StandardTokenResponse, +}; +use serde::{Deserialize, Serialize}; + +pub mod error; +mod extractor; +mod middleware; +mod util; + +pub use extractor::{OidcAccessToken, OidcClaims}; +pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; + +const SESSION_KEY: &str = "axum-oidc"; + +pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {} + +type OidcTokenResponse = StandardTokenResponse< + IdTokenFields< + AC, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + >, + CoreTokenType, +>; + +pub type IdToken = openidconnect::IdToken< + AZ, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, +>; + +type Client = openidconnect::Client< + AC, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + OidcTokenResponse, + CoreTokenType, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, +>; + +pub type BoxError = Box; + +/// OpenID Connect Client +#[derive(Clone)] +pub struct OidcClient { + scopes: Vec, + client: Client, + application_base_url: Uri, +} + +impl OidcClient { + pub async fn discover_new( + application_base_url: Uri, + issuer: String, + client_id: String, + client_secret: Option, + scopes: Vec, + ) -> Result { + let provider_metadata = + CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client) + .await?; + let client = Client::from_provider_metadata( + provider_metadata, + ClientId::new(client_id), + client_secret.map(|x| ClientSecret::new(x)), + ); + Ok(Self { + scopes, + client, + application_base_url, + }) + } +} + +/// an empty struct to be used as the default type for the additional claims generic +#[derive(Deserialize, Serialize, Debug, Clone, Copy, Default)] +pub struct EmptyAdditionalClaims {} +impl AdditionalClaims for EmptyAdditionalClaims {} +impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {} + +/// response data of the openid issuer after login +#[derive(Debug, Deserialize)] +struct OidcQuery { + code: String, + state: String, + #[allow(dead_code)] + session_state: String, +} + +/// oidc session +#[derive(Serialize, Deserialize, Debug)] +struct OidcSession { + nonce: Nonce, + csrf_token: CsrfToken, + pkce_verifier: PkceCodeVerifier, + id_token: Option, + access_token: Option, +} diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 0000000..70c9054 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,334 @@ +use std::{ + marker::PhantomData, + str::FromStr, + task::{Context, Poll}, +}; + +use axum::{ + extract::Query, + response::{IntoResponse, Redirect}, +}; +use axum_core::{extract::FromRequestParts, response::Response}; +use futures_util::future::BoxFuture; +use http::{Request, Uri}; +use tower_layer::Layer; +use tower_service::Service; +use tower_sessions::Session; + +use openidconnect::{ + core::CoreAuthenticationFlow, reqwest::async_http_client, AccessTokenHash, AuthorizationCode, + CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, + TokenResponse, +}; + +use crate::{ + error::{Error, MiddlewareError}, + extractor::{OidcAccessToken, OidcClaims}, + util::strip_oidc_from_path, + AdditionalClaims, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, SESSION_KEY, +}; + +/// Layer for the [OidcLoginMiddleware]. +#[derive(Clone, Default)] +pub struct OidcLoginLayer +where + AC: AdditionalClaims, +{ + additional: PhantomData, +} + +impl OidcLoginLayer { + pub fn new() -> Self { + Self { + additional: PhantomData, + } + } +} + +impl Layer for OidcLoginLayer +where + AC: AdditionalClaims, +{ + type Service = OidcLoginMiddleware; + + fn layer(&self, inner: I) -> Self::Service { + OidcLoginMiddleware { + inner, + additional: PhantomData, + } + } +} + +/// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect +/// Issuer to authenticate. This Middleware needs to be loaded afer [OidcAuthMiddleware]. +#[derive(Clone)] +pub struct OidcLoginMiddleware +where + AC: AdditionalClaims, +{ + inner: I, + additional: PhantomData, +} + +impl Service> for OidcLoginMiddleware +where + I: Service, Response = Response> + Send + 'static + Clone, + I::Error: Send + Into, + I::Future: Send + 'static, + AC: AdditionalClaims, + B: Send + 'static, +{ + type Response = I::Response; + type Error = MiddlewareError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner + .poll_ready(cx) + .map_err(|e| MiddlewareError::NextMiddleware(e.into())) + } + + fn call(&mut self, request: Request) -> Self::Future { + let inner = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, inner); + + if request.extensions().get::().is_some() { + Box::pin(async move { + let response: Response = inner + .call(request) + .await + .map_err(|e| MiddlewareError::NextMiddleware(e.into()))?; + return Ok(response); + }) + } else { + Box::pin(async move { + let (mut parts, _) = request.into_parts(); + + let mut oidcclient: OidcClient = parts + .extensions + .get() + .cloned() + .ok_or(MiddlewareError::AuthMiddlewareNotFound)?; + + let query = Query::::from_request_parts(&mut parts, &()) + .await + .ok(); + + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; + let login_session: Option = + session.get(SESSION_KEY).map_err(MiddlewareError::from)?; + + let handler_uri = + strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?; + + oidcclient.client = oidcclient + .client + .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); + + if let (Some(mut login_session), Some(query)) = (login_session, query) { + if login_session.csrf_token.secret() != &query.state { + return Err(MiddlewareError::CsrfTokenInvalid); + } + + let token_response = oidcclient + .client + .exchange_code(AuthorizationCode::new(query.code.to_string())) + // Set the PKCE code verifier. + .set_pkce_verifier(PkceCodeVerifier::new( + login_session.pkce_verifier.secret().to_string(), + )) + .request_async(async_http_client) + .await?; + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(MiddlewareError::IdTokenMissing)?; + let claims = id_token + .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; + + // Verify the access token hash to ensure that the access token hasn't been substituted for + // another user's. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + &id_token.signing_alg()?, + )?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(MiddlewareError::AccessTokenHashInvalid); + } + } + + login_session.id_token = Some(id_token.to_string()); + login_session.access_token = + Some(token_response.access_token().secret().to_string()); + + session.insert(SESSION_KEY, login_session).unwrap(); + + Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) + } else { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, csrf_token, nonce) = { + let mut auth = oidcclient.client.authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); + + for scope in oidcclient.scopes.iter() { + auth = auth.add_scope(Scope::new(scope.to_string())); + } + + auth.set_pkce_challenge(pkce_challenge).url() + }; + + let oidc_session = OidcSession { + nonce, + csrf_token, + pkce_verifier, + id_token: None, + access_token: None, + }; + + session.insert(SESSION_KEY, oidc_session).unwrap(); + + Ok(Redirect::temporary(auth_url.as_str()).into_response()) + } + }) + } + } +} + +/// Layer for the [OidcAuthMiddleware]. +#[derive(Clone)] +pub struct OidcAuthLayer +where + AC: AdditionalClaims, +{ + client: OidcClient, +} + +impl OidcAuthLayer { + pub fn new(client: OidcClient) -> Self { + Self { client } + } + + pub async fn discover_client( + application_base_url: Uri, + issuer: String, + client_id: String, + client_secret: Option, + scopes: Vec, + ) -> Result { + Ok(Self { + client: OidcClient::::discover_new( + application_base_url, + issuer, + client_id, + client_secret, + scopes, + ) + .await?, + }) + } +} + +impl Layer for OidcAuthLayer +where + AC: AdditionalClaims, +{ + type Service = OidcAuthMiddleware; + + fn layer(&self, inner: I) -> Self::Service { + OidcAuthMiddleware { + inner, + client: self.client.clone(), + } + } +} + +/// This middleware checks if the cached session is valid and injects the Claims, the AccessToken +/// and the OidcClient in the request. This middleware needs to be loaded for every handler that is +/// using on of the Extractors. This middleware **doesn't force a user to be +/// authenticated**. +#[derive(Clone)] +pub struct OidcAuthMiddleware +where + AC: AdditionalClaims, +{ + inner: I, + client: OidcClient, +} + +impl Service> for OidcAuthMiddleware +where + I: Service> + Send + 'static + Clone, + I::Response: IntoResponse + Send, + I::Error: Send + Into, + I::Future: Send + 'static, + AC: AdditionalClaims, + B: Send + 'static, +{ + type Response = Response; + type Error = MiddlewareError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner + .poll_ready(cx) + .map_err(|e| MiddlewareError::NextMiddleware(e.into())) + } + + fn call(&mut self, request: Request) -> Self::Future { + let inner = self.inner.clone(); + let mut inner = std::mem::replace(&mut self.inner, inner); + let mut oidcclient = self.client.clone(); + Box::pin(async move { + let (mut parts, body) = request.into_parts(); + + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; + let login_session: Option = + session.get(SESSION_KEY).map_err(MiddlewareError::from)?; + + let handler_uri = + strip_oidc_from_path(oidcclient.application_base_url.clone(), &parts.uri)?; + + oidcclient.client = oidcclient + .client + .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); + + if let Some(OidcSession { + nonce, + csrf_token: _, + pkce_verifier: _, + id_token: Some(id_token), + access_token, + }) = &login_session + { + let id_token = IdToken::::from_str(&id_token).unwrap(); + if let Ok(claims) = id_token.claims(&oidcclient.client.id_token_verifier(), nonce) { + parts.extensions.insert(OidcClaims(claims.clone())); + parts + .extensions + .insert(OidcAccessToken(access_token.clone().unwrap_or_default())); + } + } + + parts.extensions.insert(oidcclient); + + let request = Request::from_parts(parts, body); + let response: Response = inner + .call(request) + .await + .map_err(|e| MiddlewareError::NextMiddleware(e.into()))? + .into_response(); + return Ok(response); + }) + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..e601438 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,32 @@ +use http::{uri::PathAndQuery, Uri}; + +use crate::error::MiddlewareError; + +/// Helper function to remove the OpenID Connect authentication response query attributes from a +/// [`Uri`]. +pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result { + let mut base_url = base_url.into_parts(); + + base_url.path_and_query = uri + .path_and_query() + .map(|path_and_query| { + let query = path_and_query + .query() + .and_then(|uri| { + uri.split('&') + .filter(|x| { + !x.starts_with("code") + && !x.starts_with("state") + && !x.starts_with("session_state") + }) + .map(|x| x.to_string()) + .reduce(|acc, x| acc + "&" + &x) + }) + .unwrap_or_default(); + + PathAndQuery::from_maybe_shared(format!("{}?{}", path_and_query.path(), query)) + }) + .transpose()?; + + Ok(Uri::from_parts(base_url)?) +}