From 861cb70cee01699d3094e0169e325cdd9f20ff97 Mon Sep 17 00:00:00 2001 From: pfzetto Date: Thu, 6 Nov 2025 18:56:33 +0100 Subject: [PATCH] fix: #32 use OriginalUri for redirect_url --- Cargo.toml | 2 +- src/error.rs | 3 +++ src/middleware.rs | 17 +++++++++++++++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 52e1353..29a0ba6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ] [dependencies] thiserror = "2.0" axum-core = "0.5" -axum = { version = "0.8", default-features = false, features = [ "query" ] } +axum = { version = "0.8", default-features = false, features = [ "query", "original-uri" ] } tower-service = "0.3" tower-layer = "0.3" tower-sessions = { version = "0.14", default-features = false, features = [ "axum-core" ] } diff --git a/src/error.rs b/src/error.rs index 1bd47d5..537d41a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -70,6 +70,9 @@ pub enum MiddlewareError { #[error("auth middleware not found")] AuthMiddlewareNotFound, + + #[error("original url not found")] + OriginalUrlNotFound, } #[derive(Debug, Error)] diff --git a/src/middleware.rs b/src/middleware.rs index 0eddfa4..5d0deab 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -3,7 +3,10 @@ use std::{ task::{Context, Poll}, }; -use axum::response::{IntoResponse, Redirect}; +use axum::{ + extract::OriginalUri, + response::{IntoResponse, Redirect}, +}; use axum_core::response::Response; use futures_util::future::BoxFuture; use http::{request::Parts, Request}; @@ -115,6 +118,16 @@ where .get::() .ok_or(MiddlewareError::SessionNotFound)?; + let redirect_url = parts + .extensions + .get::() + .ok_or(MiddlewareError::OriginalUrlNotFound)?; + + let redirect_url = if let Some(query) = redirect_url.query() { + redirect_url.path().to_string() + "?" + query + } else { + redirect_url.path().to_string() + }; // generate a login url and redirect the user to it let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -143,7 +156,7 @@ where pkce_verifier, authenticated: None, refresh_token: None, - redirect_url: parts.uri.to_string().into(), + redirect_url: redirect_url.into(), }; session.insert(SESSION_KEY, oidc_session).await?;