Add router request guard layer

This commit is contained in:
Rob Wagner 2023-07-27 21:58:53 -04:00
parent ffb2b7d66e
commit 6c0a8cde21
No known key found for this signature in database
GPG key ID: 53CCB4497B15CF61
4 changed files with 161 additions and 5 deletions

View file

@ -7,8 +7,17 @@ repository = "https://github.com/robertwayne/axum-htmx"
categories = ["web-programming"]
keywords = ["axum", "htmx", "header", "extractor"]
readme = "README.md"
version = "0.1.0"
version = "0.2.0"
edition = "2021"
[features]
default = []
guards = ["tower", "futures-core", "pin-project-lite"]
[dependencies]
axum = { git = "https://github.com/tokio-rs/axum", branch = "main", default-features = false }
# Optional dependencies
tower = { version = "0.4", default-features = false, optional = true }
futures-core = { version = "0.3", optional = true }
pin-project-lite = { version = "0.2", optional = true }

View file

@ -49,7 +49,16 @@ present, the extractor will return `None` or `false` in most cases.
| `HX-Trigger-Name` | `HxTriggerName` | `Option<String>` |
| `HX-Trigger` | `HxTrigger` | `Option<String>` |
## Example Usage
## Request Guards
__Requires features `guards`.__
In addition to the extractors, there is also a route-wide layer request guard
for the `HX-Request` header. This will return a `403: Forbidden` response if the
header is not present, which is useful if you want to make an entire router, say
`/api`, only accessible via htmx requests.
## Example: Extractors
In this example, we'll look for the `HX-Boosted` header, which is set when
applying the [hx-boost](https://htmx.org/attributes/hx-boost/) attribute to an
@ -67,6 +76,9 @@ through a boosted anchor)_, so we look for the `HX-Boosted` header and extend
from a `_partial.html` template instead.
```rs
use axum::response::IntoResponse;
use axum_htmx::HxBoosted;
async fn get_index(HxBoosted(boosted): HxBoosted) -> impl IntoResponse {
if boosted {
// Send a template extending from _partial.html
@ -76,13 +88,26 @@ async fn get_index(HxBoosted(boosted): HxBoosted) -> impl IntoResponse {
}
```
You can also take advantage of const header values:
### Example: Router Guard
```rs
let mut headers = HeaderMap::new();
headers.insert(HX_REDIRECT, HeaderValue::from_static("/some/other/page"));
use axum::Router;
use axum_htmx::HxRequestGuardLayer;
fn protected_router() -> Router {
Router::new()
.layer(HxRequestGuardLayer::new())
}
```
### Feature Flags
<!-- markdownlint-disable -->
| Flag | Default | Description | Dependencies |
|-|-|-|-|
| `guards`| Disabled | Adds request guard layers. | `tower`, `futures-core`, `pin-project-lite` |
<!-- markdownlint-enable -->
## License
`axum-htmx` is dual-licensed under either

118
src/guard.rs Normal file
View file

@ -0,0 +1,118 @@
use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::{
http::{Request, StatusCode},
response::Response,
};
use futures_core::ready;
use pin_project_lite::pin_project;
use tower::{Layer, Service};
use crate::HX_REQUEST;
/// Checks if the request contains the `HX-Request` header, returning a `403:
/// Forbidden` response if the header is not present.
///
/// This can be used to protect routes that should only be accessed via htmx
/// requests.
#[derive(Default, Debug, Clone)]
pub struct HxRequestGuardLayer;
impl HxRequestGuardLayer {
#[allow(clippy::default_constructed_unit_structs)]
pub fn new() -> Self {
Self::default()
}
}
impl<S> Layer<S> for HxRequestGuardLayer {
type Service = HxRequestGuard<S>;
fn layer(&self, inner: S) -> Self::Service {
HxRequestGuard {
inner,
hx_request: false,
}
}
}
#[derive(Debug, Clone)]
pub struct HxRequestGuard<S> {
inner: S,
hx_request: bool,
}
impl<S, T, U> Service<Request<T>> for HxRequestGuard<S>
where
S: Service<Request<T>, Response = Response<U>>,
U: Default,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<T>) -> Self::Future {
// This will always contain a "true" value.
if req.headers().contains_key(HX_REQUEST) {
self.hx_request = true;
}
let response_future = self.inner.call(req);
ResponseFuture {
response_future,
hx_request: self.hx_request,
}
}
}
pin_project! {
pub struct ResponseFuture<F> {
#[pin]
response_future: F,
hx_request: bool,
}
}
impl<F, B, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
B: Default,
{
type Output = Result<Response<B>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let response: Response<B> = ready!(this.response_future.poll(cx))?;
match *this.hx_request {
true => Poll::Ready(Ok(response)),
false => {
let mut res = Response::new(B::default());
*res.status_mut() = StatusCode::FORBIDDEN;
Poll::Ready(Ok(res))
}
}
}
}
#[derive(Debug, Default)]
struct HxRequestGuardError;
impl fmt::Display for HxRequestGuardError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("HxRequestGuardError")
}
}
impl std::error::Error for HxRequestGuardError {}

View file

@ -1,7 +1,11 @@
#![doc = include_str!("../README.md")]
#![forbid(unsafe_code)]
pub mod extractors;
#[cfg(feature = "guards")]
pub mod guard;
pub mod headers;
pub use extractors::*;
#[cfg(feature = "guards")]
pub use guard::*;
pub use headers::*;