diff --git a/Cargo.toml b/Cargo.toml index 98c2e22..026516b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,17 @@ repository = "https://github.com/robertwayne/axum-htmx" categories = ["web-programming"] keywords = ["axum", "htmx", "header", "extractor"] readme = "README.md" -version = "0.1.0" +version = "0.2.0" edition = "2021" +[features] +default = [] +guards = ["tower", "futures-core", "pin-project-lite"] + [dependencies] axum = { git = "https://github.com/tokio-rs/axum", branch = "main", default-features = false } + +# Optional dependencies +tower = { version = "0.4", default-features = false, optional = true } +futures-core = { version = "0.3", optional = true } +pin-project-lite = { version = "0.2", optional = true } diff --git a/README.md b/README.md index cfdecb5..b0a78ec 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,16 @@ present, the extractor will return `None` or `false` in most cases. | `HX-Trigger-Name` | `HxTriggerName` | `Option` | | `HX-Trigger` | `HxTrigger` | `Option` | -## Example Usage +## Request Guards + +__Requires features `guards`.__ + +In addition to the extractors, there is also a route-wide layer request guard +for the `HX-Request` header. This will return a `403: Forbidden` response if the +header is not present, which is useful if you want to make an entire router, say +`/api`, only accessible via htmx requests. + +## Example: Extractors In this example, we'll look for the `HX-Boosted` header, which is set when applying the [hx-boost](https://htmx.org/attributes/hx-boost/) attribute to an @@ -67,6 +76,9 @@ through a boosted anchor)_, so we look for the `HX-Boosted` header and extend from a `_partial.html` template instead. ```rs +use axum::response::IntoResponse; +use axum_htmx::HxBoosted; + async fn get_index(HxBoosted(boosted): HxBoosted) -> impl IntoResponse { if boosted { // Send a template extending from _partial.html @@ -76,13 +88,26 @@ async fn get_index(HxBoosted(boosted): HxBoosted) -> impl IntoResponse { } ``` -You can also take advantage of const header values: +### Example: Router Guard ```rs -let mut headers = HeaderMap::new(); -headers.insert(HX_REDIRECT, HeaderValue::from_static("/some/other/page")); +use axum::Router; +use axum_htmx::HxRequestGuardLayer; + +fn protected_router() -> Router { + Router::new() + .layer(HxRequestGuardLayer::new()) +} ``` +### Feature Flags + + +| Flag | Default | Description | Dependencies | +|-|-|-|-| +| `guards`| Disabled | Adds request guard layers. | `tower`, `futures-core`, `pin-project-lite` | + + ## License `axum-htmx` is dual-licensed under either diff --git a/src/guard.rs b/src/guard.rs new file mode 100644 index 0000000..e626025 --- /dev/null +++ b/src/guard.rs @@ -0,0 +1,118 @@ +use std::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use axum::{ + http::{Request, StatusCode}, + response::Response, +}; +use futures_core::ready; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +use crate::HX_REQUEST; + +/// Checks if the request contains the `HX-Request` header, returning a `403: +/// Forbidden` response if the header is not present. +/// +/// This can be used to protect routes that should only be accessed via htmx +/// requests. +#[derive(Default, Debug, Clone)] +pub struct HxRequestGuardLayer; + +impl HxRequestGuardLayer { + #[allow(clippy::default_constructed_unit_structs)] + pub fn new() -> Self { + Self::default() + } +} + +impl Layer for HxRequestGuardLayer { + type Service = HxRequestGuard; + + fn layer(&self, inner: S) -> Self::Service { + HxRequestGuard { + inner, + hx_request: false, + } + } +} + +#[derive(Debug, Clone)] +pub struct HxRequestGuard { + inner: S, + hx_request: bool, +} + +impl Service> for HxRequestGuard +where + S: Service, Response = Response>, + U: Default, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // This will always contain a "true" value. + if req.headers().contains_key(HX_REQUEST) { + self.hx_request = true; + } + + let response_future = self.inner.call(req); + + ResponseFuture { + response_future, + hx_request: self.hx_request, + } + } +} + +pin_project! { + pub struct ResponseFuture { + #[pin] + response_future: F, + hx_request: bool, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + B: Default, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let response: Response = ready!(this.response_future.poll(cx))?; + + match *this.hx_request { + true => Poll::Ready(Ok(response)), + false => { + let mut res = Response::new(B::default()); + *res.status_mut() = StatusCode::FORBIDDEN; + + Poll::Ready(Ok(res)) + } + } + } +} + +#[derive(Debug, Default)] +struct HxRequestGuardError; + +impl fmt::Display for HxRequestGuardError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("HxRequestGuardError") + } +} + +impl std::error::Error for HxRequestGuardError {} diff --git a/src/lib.rs b/src/lib.rs index 9ea0b88..9459590 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ #![doc = include_str!("../README.md")] #![forbid(unsafe_code)] pub mod extractors; +#[cfg(feature = "guards")] +pub mod guard; pub mod headers; pub use extractors::*; +#[cfg(feature = "guards")] +pub use guard::*; pub use headers::*;