From a07426695a8714e7b0edb301ed4e751ab547896d Mon Sep 17 00:00:00 2001 From: Rob Wagner Date: Sat, 29 Jul 2023 16:04:51 -0400 Subject: [PATCH] Redirect on HxRequest guard failures --- src/guard.rs | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/guard.rs b/src/guard.rs index 9ef134a..1bcc37d 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -8,7 +8,7 @@ use std::{ }; use axum::{ - http::{Request, StatusCode}, + http::{header::LOCATION, Request, StatusCode}, response::Response, }; use futures_core::ready; @@ -22,41 +22,50 @@ use crate::HX_REQUEST; /// /// This can be used to protect routes that should only be accessed via htmx /// requests. -#[derive(Default, Debug, Clone)] -pub struct HxRequestGuardLayer; +#[derive(Debug, Clone)] +pub struct HxRequestGuardLayer<'a> { + redirect_to: &'a str, +} -impl HxRequestGuardLayer { - #[allow(clippy::default_constructed_unit_structs)] - pub fn new() -> Self { - Self::default() +impl<'a> HxRequestGuardLayer<'a> { + pub fn new(redirect_to: &'a str) -> Self { + Self { redirect_to } } } -impl Layer for HxRequestGuardLayer { - type Service = HxRequestGuard; +impl Default for HxRequestGuardLayer<'_> { + fn default() -> Self { + Self { redirect_to: "/" } + } +} + +impl<'a, S> Layer for HxRequestGuardLayer<'a> { + type Service = HxRequestGuard<'a, S>; fn layer(&self, inner: S) -> Self::Service { HxRequestGuard { inner, hx_request: false, + layer: self.clone(), } } } #[derive(Debug, Clone)] -pub struct HxRequestGuard { +pub struct HxRequestGuard<'a, S> { inner: S, hx_request: bool, + layer: HxRequestGuardLayer<'a>, } -impl Service> for HxRequestGuard +impl<'a, S, T, U> Service> for HxRequestGuard<'a, S> where S: Service, Response = Response>, U: Default, { type Response = S::Response; type Error = S::Error; - type Future = ResponseFuture; + type Future = ResponseFuture<'a, S::Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -73,19 +82,21 @@ where ResponseFuture { response_future, hx_request: self.hx_request, + layer: self.layer.clone(), } } } pin_project! { - pub struct ResponseFuture { + pub struct ResponseFuture<'a, F> { #[pin] response_future: F, hx_request: bool, + layer: HxRequestGuardLayer<'a>, } } -impl Future for ResponseFuture +impl<'a, F, B, E> Future for ResponseFuture<'a, F> where F: Future, E>>, B: Default, @@ -99,8 +110,11 @@ where match *this.hx_request { true => Poll::Ready(Ok(response)), false => { - let mut res = Response::new(B::default()); - *res.status_mut() = StatusCode::FORBIDDEN; + let res = Response::builder() + .status(StatusCode::SEE_OTHER) + .header(LOCATION, this.layer.redirect_to) + .body(B::default()) + .expect("failed to build response"); Poll::Ready(Ok(res)) }