diff --git a/Cargo.toml b/Cargo.toml index 36dea9d..01fb0eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ default = [] unstable = [] guards = ["tower", "futures-core", "pin-project-lite"] serde = ["dep:serde", "dep:serde_json"] +auto-vary = ["futures", "tokio", "tower"] [dependencies] axum-core = "0.4" @@ -30,9 +31,13 @@ pin-project-lite = { version = "0.2", optional = true } serde = { version = "1", features = ["derive"], optional = true } serde_json = { version = "1", optional = true } +# Optional dependencies required for the `auto-vary` feature. +tokio = { version = "1", features = ["sync"], optional = true } +futures = { version = "0.3", default-features = false, optional = true } + [dev-dependencies] axum = { version = "0.7", default-features = false } -axum-test = "14" +axum-test = "15" tokio = { version = "1", features = ["full"] } tokio-test = "0.4" diff --git a/README.md b/README.md index a7a3ee3..bee8226 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ - [Getting Started](#getting-started) - [Extractors](#extractors) - [Responders](#responders) + - [Auto Caching Management](#auto-caching-management) - [Request Guards](#request-guards) - [Examples](#examples) - [Example: Extractors](#example-extractors) @@ -76,6 +77,8 @@ any of your responses. | `HX-Trigger-After-Settle` | `HxResponseTrigger` | `axum_htmx::serde::HxEvent` | | `HX-Trigger-After-Swap` | `HxResponseTrigger` | `axum_htmx::serde::HxEvent` | +### Vary Responders + Also, there are corresponding cache-related headers, which you may want to add to `GET` responses, depending on the htmx headers. @@ -85,7 +88,7 @@ you need to add `Vary: HX-Request`. That causes the cache to be keyed based on a composite of the response URL and the `HX-Request` request header - rather than being based just on the response URL._ -Refer to [caching htmx docs section](https://htmx.org/docs/#caching) for details. +Refer to [caching htmx docs section][htmx-caching] for details. | Header | Responder | |-------------------------|---------------------| @@ -94,10 +97,27 @@ Refer to [caching htmx docs section](https://htmx.org/docs/#caching) for details | `Vary: HX-Trigger` | `VaryHxTrigger` | | `Vary: HX-Trigger-Name` | `VaryHxTriggerName` | +Look at the [Auto Caching Management](#auto-caching-management) section for +automatic `Vary` headers management. + +## Auto Caching Management + +__Requires feature `auto-vary`.__ + +Manual use of [Vary Reponders](#vary-responders) adds fragility to the code, +because of the need to manually control correspondence between used extractors +and the responders. + +We provide a [middleware](crate::AutoVaryLayer) to address this issue by +automatically adding `Vary` headers when corresponding extractors are used. +For example, on extracting [`HxRequest`], the middleware automatically adds +`Vary: hx-request` header to the response. + +Look at the usage [example][auto-vary-example]. ## Request Guards -__Requires features `guards`.__ +__Requires feature `guards`.__ In addition to the extractors, there is also a route-wide layer request guard for the `HX-Request` header. This will redirect any requests without the header @@ -207,10 +227,11 @@ fn router_two() -> Router { ## Feature Flags -| Flag | Default | Description | Dependencies | -|----------|----------|------------------------------------------------------------|---------------------------------------------| -| `guards` | Disabled | Adds request guard layers. | `tower`, `futures-core`, `pin-project-lite` | -| `serde` | Disabled | Adds serde support for the `HxEvent` and `LocationOptions` | `serde`, `serde_json` | +| Flag | Default | Description | Dependencies | +|-------------|----------|------------------------------------------------------------|---------------------------------------------| +| `auto-vary` | Disabled | A middleware to address [HTMx caching issue][htmx-caching] | `futures`, `tokio`, `tower` | +| `guards` | Disabled | Adds request guard layers. | `tower`, `futures-core`, `pin-project-lite` | +| `serde` | Disabled | Adds serde support for the `HxEvent` and `LocationOptions` | `serde`, `serde_json` | ## Contributing @@ -233,3 +254,6 @@ cargo +nightly test --all-features - **[Apache License, Version 2.0](/LICENSE-APACHE)** at your option. + +[htmx-caching]: https://htmx.org/docs/#caching +[auto-vary-example]: https://github.com/robertwayne/axum-htmx/blob/main/examples/auto-vary.rs diff --git a/examples/auto-vary.rs b/examples/auto-vary.rs new file mode 100644 index 0000000..b7868d5 --- /dev/null +++ b/examples/auto-vary.rs @@ -0,0 +1,40 @@ +//! Using `auto-vary` middleware +//! +//! Don't forget about the feature while running it: +//! `cargo run --features auto-vary --example auto-vary` +use std::time::Duration; + +use axum::{response::Html, routing::get, serve, Router}; +use axum_htmx::{AutoVaryLayer, HxRequest}; +use tokio::{net::TcpListener, time::sleep}; + +#[tokio::main] +async fn main() { + let app = Router::new() + .route("/", get(handler)) + // Add the middleware + .layer(AutoVaryLayer); + + let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap(); + serve(listener, app).await.unwrap(); +} + +// Our handler differentiates full-page GET requests from HTMx-based ones by looking at the `hx-request` +// requestheader. +// +// The middleware sees the usage of the `HxRequest` extractor and automatically adds the +// `Vary: hx-request` response header. +async fn handler(HxRequest(hx_request): HxRequest) -> Html<&'static str> { + if hx_request { + // For HTMx-based GET request, it returns a partial page update + sleep(Duration::from_secs(3)).await; + return Html("HTMx response"); + } + // While for a normal GET request, it returns the whole page + Html( + r#" + +

Loading ...

+ "#, + ) +} diff --git a/src/auto_vary.rs b/src/auto_vary.rs new file mode 100644 index 0000000..c8c0729 --- /dev/null +++ b/src/auto_vary.rs @@ -0,0 +1,228 @@ +//! A middleware to automatically add a `Vary` header when needed to address +//! [HTMx caching issue](https://htmx.org/docs/#caching) + +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use axum_core::{ + extract::Request, + response::{IntoResponse, Response}, +}; +use futures::future::{join_all, BoxFuture}; +use http::{ + header::{HeaderValue, VARY}, + Extensions, +}; +use tokio::sync::oneshot::{self, Receiver, Sender}; +use tower::{Layer, Service}; + +use crate::{ + headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR}, + HxError, +}; +#[cfg(doc)] +use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName}; + +const MIDDLEWARE_DOUBLE_USE: &str = + "Configuration error: `axum_httpx::vary_middleware` is used twice"; + +/// Addresses [HTMx caching issue](https://htmx.org/docs/#caching) +/// by automatically adding a corresponding `Vary` header when [`HxRequest`], [`HxTarget`], +/// [`HxTrigger`], [`HxTriggerName`] or their combination is used. +#[derive(Clone)] +pub struct AutoVaryLayer; + +/// Tower service for [`AutoVaryLayer`] +#[derive(Clone)] +pub struct AutoVaryMiddleware { + inner: S, +} + +pub(crate) trait Notifier { + fn sender(&mut self) -> Option>; + + fn notify(&mut self) { + if let Some(sender) = self.sender().take() { + sender.send(()).ok(); + } + } + + fn insert(extensions: &mut Extensions) -> Receiver<()>; +} + +macro_rules! define_notifiers { + ($($name:ident),*) => { + $( + #[derive(Clone)] + pub(crate) struct $name(Option>>); + + impl Notifier for $name { + fn sender(&mut self) -> Option> { + self.0.take().and_then(Arc::into_inner) + } + + fn insert(extensions: &mut Extensions) -> Receiver<()> { + let (tx, rx) = oneshot::channel(); + if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { + panic!("{}", MIDDLEWARE_DOUBLE_USE); + } + rx + } + } + )* + } +} + +define_notifiers!( + HxRequestExtracted, + HxTargetExtracted, + HxTriggerExtracted, + HxTriggerNameExtracted +); + +impl Layer for AutoVaryLayer { + type Service = AutoVaryMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + AutoVaryMiddleware { inner } + } +} + +impl Service for AutoVaryMiddleware +where + S: Service + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut request: Request) -> Self::Future { + let exts = request.extensions_mut(); + let rx_header = [ + (HxRequestExtracted::insert(exts), HX_REQUEST_STR), + (HxTargetExtracted::insert(exts), HX_TARGET_STR), + (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR), + (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR), + ]; + let future = self.inner.call(request); + Box::pin(async move { + let mut response: Response = future.await?; + let used_headers: Vec<_> = join_all( + rx_header + .into_iter() + .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }), + ) + .await + .into_iter() + .flatten() + .collect(); + + if used_headers.is_empty() { + return Ok(response); + } + + let value = match HeaderValue::from_str(&used_headers.join(", ")) { + Ok(x) => x, + Err(e) => return Ok(HxError::from(e).into_response()), + }; + + if let Err(e) = response.headers_mut().try_append(VARY, value) { + return Ok(HxError::from(e).into_response()); + } + + Ok(response) + }) + } +} + +#[cfg(test)] +mod tests { + use axum::{routing::get, Router}; + + use super::*; + use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName}; + + fn vary_headers(resp: &axum_test::TestResponse) -> Vec { + resp.iter_headers_by_name("vary").cloned().collect() + } + + fn server() -> axum_test::TestServer { + let app = Router::new() + .route("/no-extractors", get(|| async { () })) + .route("/hx-request", get(|_: HxRequest| async { () })) + .route("/hx-target", get(|_: HxTarget| async { () })) + .route("/hx-trigger", get(|_: HxTrigger| async { () })) + .route("/hx-trigger-name", get(|_: HxTriggerName| async { () })) + .route( + "/repeated-extractor", + get(|_: HxRequest, _: HxRequest| async { () }), + ) + .route( + "/multiple-extractors", + get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }), + ) + .layer(AutoVaryLayer); + axum_test::TestServer::new(app).unwrap() + } + + #[tokio::test] + async fn no_extractors() { + assert!(vary_headers(&server().get("/no-extractors").await).is_empty()); + } + + #[tokio::test] + async fn single_hx_request() { + assert_eq!( + vary_headers(&server().get("/hx-request").await), + ["hx-request"] + ); + } + + #[tokio::test] + async fn single_hx_target() { + assert_eq!( + vary_headers(&server().get("/hx-target").await), + ["hx-target"] + ); + } + + #[tokio::test] + async fn single_hx_trigger() { + assert_eq!( + vary_headers(&server().get("/hx-trigger").await), + ["hx-trigger"] + ); + } + + #[tokio::test] + async fn single_hx_trigger_name() { + assert_eq!( + vary_headers(&server().get("/hx-trigger-name").await), + ["hx-trigger-name"] + ); + } + + #[tokio::test] + async fn repeated_extractor() { + assert_eq!( + vary_headers(&server().get("/repeated-extractor").await), + ["hx-request"] + ); + } + + // Extractors can be used multiple times e.g. in middlewares + #[tokio::test] + async fn multiple_extractors() { + assert_eq!( + vary_headers(&server().get("/multiple-extractors").await), + ["hx-request, hx-target, hx-trigger, hx-trigger-name"], + ); + } +} diff --git a/src/extractors.rs b/src/extractors.rs index 81ab239..5f25683 100644 --- a/src/extractors.rs +++ b/src/extractors.rs @@ -137,6 +137,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + #[cfg(feature = "auto-vary")] + parts + .extensions + .get_mut::() + .map(crate::auto_vary::Notifier::notify); + if parts.headers.contains_key(HX_REQUEST) { return Ok(HxRequest(true)); } else { @@ -164,6 +170,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + #[cfg(feature = "auto-vary")] + parts + .extensions + .get_mut::() + .map(crate::auto_vary::Notifier::notify); + if let Some(target) = parts.headers.get(HX_TARGET) { if let Ok(target) = target.to_str() { return Ok(HxTarget(Some(target.to_string()))); @@ -193,6 +205,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + #[cfg(feature = "auto-vary")] + parts + .extensions + .get_mut::() + .map(crate::auto_vary::Notifier::notify); + if let Some(trigger_name) = parts.headers.get(HX_TRIGGER_NAME) { if let Ok(trigger_name) = trigger_name.to_str() { return Ok(HxTriggerName(Some(trigger_name.to_string()))); @@ -222,6 +240,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + #[cfg(feature = "auto-vary")] + parts + .extensions + .get_mut::() + .map(crate::auto_vary::Notifier::notify); + if let Some(trigger) = parts.headers.get(HX_TRIGGER) { if let Ok(trigger) = trigger.to_str() { return Ok(HxTrigger(Some(trigger.to_string()))); diff --git a/src/lib.rs b/src/lib.rs index 439ff68..fc1bd36 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ mod error; pub use error::*; +#[cfg(feature = "auto-vary")] +#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))] +pub mod auto_vary; pub mod extractors; #[cfg(feature = "guards")] #[cfg_attr(feature = "unstable", doc(cfg(feature = "guards")))] @@ -12,6 +15,10 @@ pub mod guard; pub mod headers; pub mod responders; +#[cfg(feature = "auto-vary")] +#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))] +#[doc(inline)] +pub use auto_vary::*; #[doc(inline)] pub use extractors::*; #[cfg(feature = "guards")]