Merge pull request #2 from robertwayne/guard-redirects (closes #1)

HxRequestGuard now redirects on failure
This commit is contained in:
Rob 2023-07-29 17:08:51 -04:00 committed by GitHub
commit 0476a16ebc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 24 deletions

View file

@ -54,9 +54,8 @@ present, the extractor will return `None` or `false` in most cases.
__Requires features `guards`.__ __Requires features `guards`.__
In addition to the extractors, there is also a route-wide layer request guard In addition to the extractors, there is also a route-wide layer request guard
for the `HX-Request` header. This will return a `403: Forbidden` response if the for the `HX-Request` header. This will redirect any requests without the header
header is not present, which is useful if you want to make an entire router, say to "/" by default.
`/api`, only accessible via htmx requests.
_It should be noted that this is NOT a replacement for an auth guard. A user can _It should be noted that this is NOT a replacement for an auth guard. A user can
trivially set the `HX-Request` header themselves. This is merely a convenience trivially set the `HX-Request` header themselves. This is merely a convenience
@ -101,7 +100,13 @@ use axum_htmx::HxRequestGuardLayer;
fn protected_router() -> Router { fn protected_router() -> Router {
Router::new() Router::new()
.layer(HxRequestGuardLayer::new()) // Redirects to "/" if the HX-Request header is not present
.layer(HxRequestGuardLayer::default())
}
fn other_route() -> Router {
Router::new()
.layer(HxRequestGuardLayer::new("/redirect-to-this-route"))
} }
``` ```

View file

@ -8,7 +8,7 @@ use std::{
}; };
use axum::{ use axum::{
http::{Request, StatusCode}, http::{header::LOCATION, Request, StatusCode},
response::Response, response::Response,
}; };
use futures_core::ready; use futures_core::ready;
@ -17,46 +17,55 @@ use tower::{Layer, Service};
use crate::HX_REQUEST; use crate::HX_REQUEST;
/// Checks if the request contains the `HX-Request` header, returning a `403: /// Checks if the request contains the `HX-Request` header, redirecting to the
/// Forbidden` response if the header is not present. /// given location if not.
/// ///
/// This can be used to protect routes that should only be accessed via htmx /// This can be useful for preventing users from accidently ending up on a route
/// requests. /// which would otherwise return only partial HTML data.
#[derive(Default, Debug, Clone)] #[derive(Debug, Clone)]
pub struct HxRequestGuardLayer; pub struct HxRequestGuardLayer<'a> {
redirect_to: &'a str,
}
impl HxRequestGuardLayer { impl<'a> HxRequestGuardLayer<'a> {
#[allow(clippy::default_constructed_unit_structs)] pub fn new(redirect_to: &'a str) -> Self {
pub fn new() -> Self { Self { redirect_to }
Self::default()
} }
} }
impl<S> Layer<S> for HxRequestGuardLayer { impl Default for HxRequestGuardLayer<'_> {
type Service = HxRequestGuard<S>; fn default() -> Self {
Self { redirect_to: "/" }
}
}
impl<'a, S> Layer<S> for HxRequestGuardLayer<'a> {
type Service = HxRequestGuard<'a, S>;
fn layer(&self, inner: S) -> Self::Service { fn layer(&self, inner: S) -> Self::Service {
HxRequestGuard { HxRequestGuard {
inner, inner,
hx_request: false, hx_request: false,
layer: self.clone(),
} }
} }
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HxRequestGuard<S> { pub struct HxRequestGuard<'a, S> {
inner: S, inner: S,
hx_request: bool, hx_request: bool,
layer: HxRequestGuardLayer<'a>,
} }
impl<S, T, U> Service<Request<T>> for HxRequestGuard<S> impl<'a, S, T, U> Service<Request<T>> for HxRequestGuard<'a, S>
where where
S: Service<Request<T>, Response = Response<U>>, S: Service<Request<T>, Response = Response<U>>,
U: Default, U: Default,
{ {
type Response = S::Response; type Response = S::Response;
type Error = S::Error; type Error = S::Error;
type Future = ResponseFuture<S::Future>; type Future = ResponseFuture<'a, S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx) self.inner.poll_ready(cx)
@ -73,19 +82,21 @@ where
ResponseFuture { ResponseFuture {
response_future, response_future,
hx_request: self.hx_request, hx_request: self.hx_request,
layer: self.layer.clone(),
} }
} }
} }
pin_project! { pin_project! {
pub struct ResponseFuture<F> { pub struct ResponseFuture<'a, F> {
#[pin] #[pin]
response_future: F, response_future: F,
hx_request: bool, hx_request: bool,
layer: HxRequestGuardLayer<'a>,
} }
} }
impl<F, B, E> Future for ResponseFuture<F> impl<'a, F, B, E> Future for ResponseFuture<'a, F>
where where
F: Future<Output = Result<Response<B>, E>>, F: Future<Output = Result<Response<B>, E>>,
B: Default, B: Default,
@ -99,8 +110,11 @@ where
match *this.hx_request { match *this.hx_request {
true => Poll::Ready(Ok(response)), true => Poll::Ready(Ok(response)),
false => { false => {
let mut res = Response::new(B::default()); let res = Response::builder()
*res.status_mut() = StatusCode::FORBIDDEN; .status(StatusCode::SEE_OTHER)
.header(LOCATION, this.layer.redirect_to)
.body(B::default())
.expect("failed to build response");
Poll::Ready(Ok(res)) Poll::Ready(Ok(res))
} }