use std::{
    marker::PhantomData,
    task::{Context, Poll},
};

use axum::{
    extract::OriginalUri,
    response::{IntoResponse, Redirect},
};
use axum_core::response::Response;
use futures_util::future::BoxFuture;
use http::{request::Parts, Request};
use tower_layer::Layer;
use tower_service::Service;
use tower_sessions::Session;

use openidconnect::{
    core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey},
    AccessToken, AccessTokenHash, AuthenticationContextClass, CsrfToken, IdTokenClaims,
    IdTokenVerifier, Nonce, NonceVerifier, OAuth2TokenResponse, PkceCodeChallenge, RefreshToken,
    RequestTokenError::ServerResponse,
    Scope, TokenResponse,
};

use crate::{
    error::MiddlewareError,
    extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout},
    AdditionalClaims, AuthenticatedSession, BoxError, ClearSessionFlag, IdToken, OidcClient,
    OidcSession, SESSION_KEY,
};

/// Layer for the [`OidcLoginMiddleware`].
///
/// The layer stores no runtime data; `AC` only selects the additional-claims type the wrapped
/// middleware works with.
#[derive(Clone)]
pub struct OidcLoginLayer<AC>
where
    AC: AdditionalClaims,
{
    additional: PhantomData<AC>,
}

impl<AC: AdditionalClaims> OidcLoginLayer<AC> {
    /// Create a new [`OidcLoginLayer`].
    pub fn new() -> Self {
        Self {
            additional: PhantomData,
        }
    }
}

// Implemented by hand instead of `#[derive(Default)]`: the derive would add an `AC: Default`
// bound although only a `PhantomData<AC>` is stored, making `default()` unavailable for claim
// types that are not `Default`.
impl<AC: AdditionalClaims> Default for OidcLoginLayer<AC> {
    fn default() -> Self {
        Self::new()
    }
}

impl<I, AC> Layer<I> for OidcLoginLayer<AC>
where
    AC: AdditionalClaims,
{
    type Service = OidcLoginMiddleware<I, AC>;

    /// Wrap `inner` in an [`OidcLoginMiddleware`].
    fn layer(&self, inner: I) -> Self::Service {
        OidcLoginMiddleware {
            inner,
            additional: PhantomData,
        }
    }
}

/// This middleware forces the user to be authenticated and redirects the user to the OpenID Connect
/// Issuer to authenticate. This Middleware needs to be loaded after [`OidcAuthMiddleware`].
///
/// Requests that already carry an [`OidcAccessToken`] extension are passed to the inner service.
/// Every other request gets a redirect to the issuer's authorization URL. Before redirecting, a
/// fresh nonce, CSRF token and PKCE verifier are stored in the session. The middleware reads the
/// [`OidcClient`], the [`Session`] and the [`OriginalUri`] from the request extensions and fails
/// if any of them is missing.
#[derive(Clone)] pub struct OidcLoginMiddleware where AC: AdditionalClaims, { inner: I, additional: PhantomData, } impl Service> for OidcLoginMiddleware where I: Service, Response = Response> + Send + 'static + Clone, I::Error: Send + Into, I::Future: Send + 'static, AC: AdditionalClaims, B: Send + 'static, { type Response = I::Response; type Error = MiddlewareError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner .poll_ready(cx) .map_err(|e| MiddlewareError::NextMiddleware(e.into())) } fn call(&mut self, request: Request) -> Self::Future { let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); if request.extensions().get::().is_some() { // the OidcAuthMiddleware had a valid id token Box::pin(async move { let response: Response = inner .call(request) .await .map_err(|e| MiddlewareError::NextMiddleware(e.into()))?; Ok(response) }) } else { // no valid id token or refresh token was found and the user has to login Box::pin(async move { let (parts, _) = request.into_parts(); let oidcclient: OidcClient = parts .extensions .get() .cloned() .ok_or(MiddlewareError::AuthMiddlewareNotFound)?; let session = parts .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; let redirect_url = parts .extensions .get::() .ok_or(MiddlewareError::OriginalUrlNotFound)?; let redirect_url = if let Some(query) = redirect_url.query() { redirect_url.path().to_string() + "?" 
+ query } else { redirect_url.path().to_string() }; // generate a login url and redirect the user to it let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, csrf_token, nonce) = { let mut auth = oidcclient.client.authorize_url( CoreAuthenticationFlow::AuthorizationCode, CsrfToken::new_random, Nonce::new_random, ); for scope in oidcclient.scopes.iter() { auth = auth.add_scope(Scope::new(scope.to_string())); } if let Some(acr) = oidcclient.auth_context_class { auth = auth .add_auth_context_value(AuthenticationContextClass::new(acr.into())); } auth.set_pkce_challenge(pkce_challenge).url() }; let oidc_session = OidcSession:: { nonce, csrf_token, pkce_verifier, authenticated: None, refresh_token: None, redirect_url: redirect_url.into(), }; session.insert(SESSION_KEY, oidc_session).await?; Ok(Redirect::to(auth_url.as_str()).into_response()) }) } } } /// Layer for the [`OidcAuthMiddleware`]. #[derive(Clone)] pub struct OidcAuthLayer where AC: AdditionalClaims, { client: OidcClient, } impl OidcAuthLayer { pub fn new(client: OidcClient) -> Self { Self { client } } } impl From> for OidcAuthLayer { fn from(value: OidcClient) -> Self { Self::new(value) } } impl Layer for OidcAuthLayer where AC: AdditionalClaims, { type Service = OidcAuthMiddleware; fn layer(&self, inner: I) -> Self::Service { OidcAuthMiddleware { inner, client: self.client.clone(), } } } /// This middleware checks if the cached session is valid and injects the Claims, the AccessToken /// and the OidcClient in the request. This middleware needs to be loaded for every handler that is /// using on of the Extractors. This middleware **doesn't force a user to be /// authenticated**. 
#[derive(Clone)] pub struct OidcAuthMiddleware where AC: AdditionalClaims, { inner: I, client: OidcClient, } impl Service> for OidcAuthMiddleware where I: Service> + Send + 'static + Clone, I::Response: IntoResponse + Send, I::Error: Send + Into, I::Future: Send + 'static, AC: AdditionalClaims, B: Send + 'static, { type Response = Response; type Error = MiddlewareError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner .poll_ready(cx) .map_err(|e| MiddlewareError::NextMiddleware(e.into())) } fn call(&mut self, request: Request) -> Self::Future { let inner = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, inner); let oidcclient = self.client.clone(); Box::pin(async move { let (mut parts, body) = request.into_parts(); let session = parts .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)? .clone(); let mut login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; if let Some(login_session) = &mut login_session { let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { session .id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) .ok() .cloned() .map(|claims| (session, claims)) }); if let Some((session, claims)) = id_token_claims { // stored id token is valid and can be used insert_extensions(&mut parts, claims.clone(), &oidcclient, session); } else if let Some(refresh_token) = login_session.refresh_token.as_ref() { // session is expired but can be refreshed using the refresh_token if let Some((claims, authenticated_session, refresh_token)) = try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? 
{ insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); login_session.authenticated = Some(authenticated_session); if let Some(refresh_token) = refresh_token { login_session.refresh_token = Some(refresh_token); } }; // save refreshed session or delete it when the token couldn't be refreshed let session = parts .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; session.insert(SESSION_KEY, login_session).await?; } } parts.extensions.insert(oidcclient); let request = Request::from_parts(parts, body); let response: Response = inner .call(request) .await .map_err(|e| MiddlewareError::NextMiddleware(e.into()))? .into_response(); let has_logout_ext = response.extensions().get::().is_some(); if let (true, Some(mut login_session)) = (has_logout_ext, login_session) { login_session.authenticated = None; login_session.refresh_token = None; session.insert(SESSION_KEY, login_session).await?; } Ok(response) }) } } /// insert all extensions that are used by the extractors fn insert_extensions( parts: &mut Parts, claims: IdTokenClaims, client: &OidcClient, authenticated_session: &AuthenticatedSession, ) { parts.extensions.insert(OidcClaims(claims)); parts.extensions.insert(OidcAccessToken( authenticated_session.access_token.secret().to_string(), )); let rp_initiated_logout = client .end_session_endpoint .as_ref() .map(|end_session_endpoint| OidcRpInitiatedLogout { end_session_endpoint: end_session_endpoint.clone(), id_token_hint: authenticated_session.id_token.to_string().into(), client_id: client.client_id.clone(), post_logout_redirect_uri: None, state: None, }); parts.extensions.insert(rp_initiated_logout); } /// Verify the access token hash to ensure that the access token hasn't been substituted for /// another user's. 
/// Returns `Ok` when access token is valid fn validate_access_token_hash( id_token: &IdToken, id_token_verifier: IdTokenVerifier, access_token: &AccessToken, claims: &IdTokenClaims, ) -> Result<(), MiddlewareError> { if let Some(expected_access_token_hash) = claims.access_token_hash() { let actual_access_token_hash = AccessTokenHash::from_token( access_token, id_token.signing_alg()?, id_token.signing_key(&id_token_verifier)?, )?; if actual_access_token_hash == *expected_access_token_hash { Ok(()) } else { Err(MiddlewareError::AccessTokenHashInvalid) } } else { Ok(()) } } async fn try_refresh_token( client: &OidcClient, refresh_token: &RefreshToken, nonce: &Nonce, ) -> Result< Option<( IdTokenClaims, AuthenticatedSession, Option, )>, MiddlewareError, > { let mut refresh_request = client.client.exchange_refresh_token(refresh_token)?; for scope in client.scopes.iter() { refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); } match refresh_request.request_async(&client.http_client).await { Ok(token_response) => { // Extract the ID token claims after verifying its authenticity and nonce. let id_token = token_response .id_token() .ok_or(MiddlewareError::IdTokenMissing)?; let id_token_verifier = client.client.id_token_verifier(); let claims = id_token.claims(&id_token_verifier, |claims_nonce: Option<&Nonce>| { match claims_nonce { Some(_) => nonce.verify(claims_nonce), None => Ok(()), } })?; validate_access_token_hash( id_token, id_token_verifier, token_response.access_token(), claims, )?; let authenticated_session = AuthenticatedSession { id_token: id_token.clone(), access_token: token_response.access_token().clone(), }; Ok(Some(( claims.clone(), authenticated_session, token_response.refresh_token().cloned(), ))) } Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { // Refresh failed, refresh_token most likely expired or // invalid, the session can be considered lost Ok(None) } Err(err) => Err(err.into()), } }