mirror of
https://github.com/robertwayne/axum-htmx
synced 2025-01-27 00:49:01 +01:00
Redirect on HxRequest guard failures
This commit is contained in:
parent
35f86927e7
commit
a07426695a
1 changed files with 30 additions and 16 deletions
46
src/guard.rs
46
src/guard.rs
|
@ -8,7 +8,7 @@ use std::{
|
|||
};
|
||||
|
||||
use axum::{
|
||||
http::{Request, StatusCode},
|
||||
http::{header::LOCATION, Request, StatusCode},
|
||||
response::Response,
|
||||
};
|
||||
use futures_core::ready;
|
||||
|
@ -22,41 +22,50 @@ use crate::HX_REQUEST;
|
|||
///
|
||||
/// This can be used to protect routes that should only be accessed via htmx
|
||||
/// requests.
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct HxRequestGuardLayer;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HxRequestGuardLayer<'a> {
|
||||
redirect_to: &'a str,
|
||||
}
|
||||
|
||||
impl HxRequestGuardLayer {
|
||||
#[allow(clippy::default_constructed_unit_structs)]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
impl<'a> HxRequestGuardLayer<'a> {
|
||||
pub fn new(redirect_to: &'a str) -> Self {
|
||||
Self { redirect_to }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Layer<S> for HxRequestGuardLayer {
|
||||
type Service = HxRequestGuard<S>;
|
||||
impl Default for HxRequestGuardLayer<'_> {
|
||||
fn default() -> Self {
|
||||
Self { redirect_to: "/" }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> Layer<S> for HxRequestGuardLayer<'a> {
|
||||
type Service = HxRequestGuard<'a, S>;
|
||||
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
HxRequestGuard {
|
||||
inner,
|
||||
hx_request: false,
|
||||
layer: self.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HxRequestGuard<S> {
|
||||
pub struct HxRequestGuard<'a, S> {
|
||||
inner: S,
|
||||
hx_request: bool,
|
||||
layer: HxRequestGuardLayer<'a>,
|
||||
}
|
||||
|
||||
impl<S, T, U> Service<Request<T>> for HxRequestGuard<S>
|
||||
impl<'a, S, T, U> Service<Request<T>> for HxRequestGuard<'a, S>
|
||||
where
|
||||
S: Service<Request<T>, Response = Response<U>>,
|
||||
U: Default,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = ResponseFuture<S::Future>;
|
||||
type Future = ResponseFuture<'a, S::Future>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
|
@ -73,19 +82,21 @@ where
|
|||
ResponseFuture {
|
||||
response_future,
|
||||
hx_request: self.hx_request,
|
||||
layer: self.layer.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct ResponseFuture<F> {
|
||||
pub struct ResponseFuture<'a, F> {
|
||||
#[pin]
|
||||
response_future: F,
|
||||
hx_request: bool,
|
||||
layer: HxRequestGuardLayer<'a>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, B, E> Future for ResponseFuture<F>
|
||||
impl<'a, F, B, E> Future for ResponseFuture<'a, F>
|
||||
where
|
||||
F: Future<Output = Result<Response<B>, E>>,
|
||||
B: Default,
|
||||
|
@ -99,8 +110,11 @@ where
|
|||
match *this.hx_request {
|
||||
true => Poll::Ready(Ok(response)),
|
||||
false => {
|
||||
let mut res = Response::new(B::default());
|
||||
*res.status_mut() = StatusCode::FORBIDDEN;
|
||||
let res = Response::builder()
|
||||
.status(StatusCode::SEE_OTHER)
|
||||
.header(LOCATION, this.layer.redirect_to)
|
||||
.body(B::default())
|
||||
.expect("failed to build response");
|
||||
|
||||
Poll::Ready(Ok(res))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue