diff --git a/README.md b/README.md index a4c8246..ca3de95 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,8 @@ present, the extractor will return `None` or `false` in most cases. __Requires features `guards`.__ In addition to the extractors, there is also a route-wide layer request guard -for the `HX-Request` header. This will return a `403: Forbidden` response if the -header is not present, which is useful if you want to make an entire router, say -`/api`, only accessible via htmx requests. +for the `HX-Request` header. This will redirect any requests without the header +to "/" by default. _It should be noted that this is NOT a replacement for an auth guard. A user can trivially set the `HX-Request` header themselves. This is merely a convenience @@ -101,7 +100,13 @@ use axum_htmx::HxRequestGuardLayer; fn protected_router() -> Router { Router::new() - .layer(HxRequestGuardLayer::new()) + // Redirects to "/" if the HX-Request header is not present + .layer(HxRequestGuardLayer::default()) +} + +fn other_route() -> Router { + Router::new() + .layer(HxRequestGuardLayer::new("/redirect-to-this-route")) } ``` diff --git a/src/guard.rs b/src/guard.rs index 9ef134a..66b440b 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -8,7 +8,7 @@ use std::{ }; use axum::{ - http::{Request, StatusCode}, + http::{header::LOCATION, Request, StatusCode}, response::Response, }; use futures_core::ready; @@ -17,46 +17,55 @@ use tower::{Layer, Service}; use crate::HX_REQUEST; -/// Checks if the request contains the `HX-Request` header, returning a `403: -/// Forbidden` response if the header is not present. +/// Checks if the request contains the `HX-Request` header, redirecting to the +/// given location if not. /// -/// This can be used to protect routes that should only be accessed via htmx -/// requests. -#[derive(Default, Debug, Clone)] -pub struct HxRequestGuardLayer; +/// This can be useful for preventing users from accidently ending up on a route +/// which would otherwise return only partial HTML data. +#[derive(Debug, Clone)] +pub struct HxRequestGuardLayer<'a> { + redirect_to: &'a str, +} -impl HxRequestGuardLayer { - #[allow(clippy::default_constructed_unit_structs)] - pub fn new() -> Self { - Self::default() +impl<'a> HxRequestGuardLayer<'a> { + pub fn new(redirect_to: &'a str) -> Self { + Self { redirect_to } } } -impl Layer for HxRequestGuardLayer { - type Service = HxRequestGuard; +impl Default for HxRequestGuardLayer<'_> { + fn default() -> Self { + Self { redirect_to: "/" } + } +} + +impl<'a, S> Layer for HxRequestGuardLayer<'a> { + type Service = HxRequestGuard<'a, S>; fn layer(&self, inner: S) -> Self::Service { HxRequestGuard { inner, hx_request: false, + layer: self.clone(), } } } #[derive(Debug, Clone)] -pub struct HxRequestGuard { +pub struct HxRequestGuard<'a, S> { inner: S, hx_request: bool, + layer: HxRequestGuardLayer<'a>, } -impl Service> for HxRequestGuard +impl<'a, S, T, U> Service> for HxRequestGuard<'a, S> where S: Service, Response = Response>, U: Default, { type Response = S::Response; type Error = S::Error; - type Future = ResponseFuture; + type Future = ResponseFuture<'a, S::Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) @@ -73,19 +82,21 @@ where ResponseFuture { response_future, hx_request: self.hx_request, + layer: self.layer.clone(), } } } pin_project! { - pub struct ResponseFuture { + pub struct ResponseFuture<'a, F> { #[pin] response_future: F, hx_request: bool, + layer: HxRequestGuardLayer<'a>, } } -impl Future for ResponseFuture +impl<'a, F, B, E> Future for ResponseFuture<'a, F> where F: Future, E>>, B: Default, @@ -99,8 +110,11 @@ where match *this.hx_request { true => Poll::Ready(Ok(response)), false => { - let mut res = Response::new(B::default()); - *res.status_mut() = StatusCode::FORBIDDEN; + let res = Response::builder() + .status(StatusCode::SEE_OTHER) + .header(LOCATION, this.layer.redirect_to) + .body(B::default()) + .expect("failed to build response"); Poll::Ready(Ok(res)) }