diff --git a/Cargo.toml b/Cargo.toml index 36dea9d..b59509a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ serde = ["dep:serde", "dep:serde_json"] axum-core = "0.4" http = { version = "1.0", default-features = false } async-trait = "0.1" +axum = "0.7" # TODO: remove +tokio = { version = "1", features = ["sync"] } # TODO: hide behind a feature? # Optional dependencies required for the `guards` feature. tower = { version = "0.4", default-features = false, optional = true } diff --git a/src/extractors.rs b/src/extractors.rs index 81ab239..0f53604 100644 --- a/src/extractors.rs +++ b/src/extractors.rs @@ -137,6 +137,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + use crate::vary_middleware::{HxRequestExtracted, Notifier}; + parts + .extensions + .get_mut::() + .map(Notifier::notify); + if parts.headers.contains_key(HX_REQUEST) { return Ok(HxRequest(true)); } else { @@ -164,6 +170,12 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + use crate::vary_middleware::{HxTargetExtracted, Notifier}; + parts + .extensions + .get_mut::() + .map(Notifier::notify); + if let Some(target) = parts.headers.get(HX_TARGET) { if let Ok(target) = target.to_str() { return Ok(HxTarget(Some(target.to_string()))); diff --git a/src/lib.rs b/src/lib.rs index 439ff68..5d4dae1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,3 +22,6 @@ pub use guard::*; pub use headers::*; #[doc(inline)] pub use responders::*; + +pub(crate) mod vary_middleware; +pub use vary_middleware::vary_middleware; diff --git a/src/vary_middleware.rs b/src/vary_middleware.rs new file mode 100644 index 0000000..57a4352 --- /dev/null +++ b/src/vary_middleware.rs @@ -0,0 +1,134 @@ +use crate::{ + headers::{HX_REQUEST_STR, HX_TARGET_STR}, + HxError, +}; +use axum::{extract::Request, middleware::Next, response::Response}; +use axum_core::response::IntoResponse; +use http::{ + header::{HeaderValue, VARY}, + Extensions, +}; +use std::sync::Arc; +use tokio::sync::oneshot::{self, Receiver, Sender}; + +const MIDDLEWARE_DOUBLE_USE: &str = + "Configuration error: `axum_httpx::vary_middleware` is used twice"; + +#[derive(Clone)] +pub(crate) struct HxRequestExtracted(Option>>); + +#[derive(Clone)] +pub(crate) struct HxTargetExtracted(Option>>); + +pub trait Notifier { + fn sender(&mut self) -> Option>; + + fn notify(&mut self) { + if let Some(sender) = self.sender().take() { + sender.send(()).ok(); + } + } +} + +impl Notifier for HxRequestExtracted { + fn sender(&mut self) -> Option> { + self.0.take().and_then(Arc::into_inner) + } +} + +impl Notifier for HxTargetExtracted { + fn sender(&mut self) -> Option> { + self.0.take().and_then(Arc::into_inner) + } +} + +impl HxRequestExtracted { + fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { + let (tx, rx) = oneshot::channel(); + if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { + panic!("{}", MIDDLEWARE_DOUBLE_USE); + } + rx + } +} + +impl HxTargetExtracted { + fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { + let (tx, rx) = oneshot::channel(); + if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { + panic!("{}", MIDDLEWARE_DOUBLE_USE); + } + rx + } +} + +pub async fn vary_middleware(mut request: Request, next: Next) -> Response { + let hx_request_rx = HxRequestExtracted::insert_into_extensions(request.extensions_mut()); + let hx_target_rx = HxTargetExtracted::insert_into_extensions(request.extensions_mut()); + + let mut response = next.run(request).await; + + let mut used = Vec::with_capacity(4); + if hx_request_rx.await.is_ok() { + used.push(HX_REQUEST_STR) + } + if hx_target_rx.await.is_ok() { + used.push(HX_TARGET_STR) + } + + if !used.is_empty() { + let value = match HeaderValue::from_str(&used.join(", ")) { + Ok(x) => x, + Err(e) => return HxError::from(e).into_response(), + }; + if let Err(e) = response.headers_mut().try_append(VARY, value) { + return HxError::from(e).into_response(); + } + } + + response +} + +#[cfg(test)] +mod tests { + use crate::{HxRequest, HxTarget}; + use axum::{routing::get, Router}; + + use super::*; + + fn vary_headers(resp: &axum_test::TestResponse) -> Vec { + resp.iter_headers_by_name("vary").cloned().collect() + } + + #[tokio::test] + async fn multiple_headers() { + let app = Router::new() + .route("/no-extractors", get(|| async { () })) + .route("/single-extractor", get(|_: HxRequest| async { () })) + // Extractors can be used multiple times e.g. in middlewares + .route( + "/repeated-extractor", + get(|_: HxRequest, _: HxRequest| async { () }), + ) + .route( + "/multiple-extractors", + get(|_: HxRequest, _: HxTarget| async { () }), + ) + .layer(axum::middleware::from_fn(vary_middleware)); + let server = axum_test::TestServer::new(app).unwrap(); + + assert!(vary_headers(&server.get("/no-extractors").await).is_empty()); + assert_eq!( + vary_headers(&server.get("/single-extractor").await), + [HX_REQUEST_STR] + ); + assert_eq!( + vary_headers(&server.get("/repeated-extractor").await), + [HX_REQUEST_STR] + ); + assert_eq!( + vary_headers(&server.get("/multiple-extractors").await), + [format!("{HX_REQUEST_STR}, {HX_TARGET_STR}")] + ); + } +}