diff --git a/src/extractors.rs b/src/extractors.rs index 0f53604..7a90f11 100644 --- a/src/extractors.rs +++ b/src/extractors.rs @@ -137,11 +137,10 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { - use crate::vary_middleware::{HxRequestExtracted, Notifier}; parts .extensions - .get_mut::() - .map(Notifier::notify); + .get_mut::() + .map(crate::vary_middleware::Notifier::notify); if parts.headers.contains_key(HX_REQUEST) { return Ok(HxRequest(true)); @@ -170,11 +169,10 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { - use crate::vary_middleware::{HxTargetExtracted, Notifier}; parts .extensions - .get_mut::() - .map(Notifier::notify); + .get_mut::() + .map(crate::vary_middleware::Notifier::notify); if let Some(target) = parts.headers.get(HX_TARGET) { if let Ok(target) = target.to_str() { @@ -205,6 +203,11 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get_mut::() + .map(crate::vary_middleware::Notifier::notify); + if let Some(trigger_name) = parts.headers.get(HX_TRIGGER_NAME) { if let Ok(trigger_name) = trigger_name.to_str() { return Ok(HxTriggerName(Some(trigger_name.to_string()))); @@ -234,6 +237,11 @@ where type Rejection = std::convert::Infallible; async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get_mut::() + .map(crate::vary_middleware::Notifier::notify); + if let Some(trigger) = parts.headers.get(HX_TRIGGER) { if let Ok(trigger) = trigger.to_str() { return Ok(HxTrigger(Some(trigger.to_string()))); diff --git a/src/vary_middleware.rs b/src/vary_middleware.rs index ccf3080..9349008 100644 --- a/src/vary_middleware.rs +++ b/src/vary_middleware.rs @@ -9,7 +9,7 @@ use http::{ use tokio::sync::oneshot::{self, Receiver, Sender}; use crate::{ - headers::{HX_REQUEST_STR, HX_TARGET_STR}, + headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR}, HxError, }; @@ -22,6 +22,12 @@ pub(crate) struct HxRequestExtracted(Option>>); #[derive(Clone)] pub(crate) struct HxTargetExtracted(Option>>); +#[derive(Clone)] +pub(crate) struct HxTriggerExtracted(Option>>); + +#[derive(Clone)] +pub(crate) struct HxTriggerNameExtracted(Option>>); + pub trait Notifier { fn sender(&mut self) -> Option>; @@ -44,6 +50,18 @@ impl Notifier for HxTargetExtracted { } } +impl Notifier for HxTriggerExtracted { + fn sender(&mut self) -> Option> { + self.0.take().and_then(Arc::into_inner) + } +} + +impl Notifier for HxTriggerNameExtracted { + fn sender(&mut self) -> Option> { + self.0.take().and_then(Arc::into_inner) + } +} + impl HxRequestExtracted { fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { let (tx, rx) = oneshot::channel(); @@ -64,9 +82,32 @@ impl HxTargetExtracted { } } +impl HxTriggerExtracted { + fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { + let (tx, rx) = oneshot::channel(); + if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { + panic!("{}", MIDDLEWARE_DOUBLE_USE); + } + rx + } +} + +impl HxTriggerNameExtracted { + fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { + let (tx, rx) = oneshot::channel(); + if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { + panic!("{}", MIDDLEWARE_DOUBLE_USE); + } + rx + } +} + pub async fn vary_middleware(mut request: Request, next: Next) -> Response { let hx_request_rx = HxRequestExtracted::insert_into_extensions(request.extensions_mut()); let hx_target_rx = HxTargetExtracted::insert_into_extensions(request.extensions_mut()); + let hx_trigger_rx = HxTriggerExtracted::insert_into_extensions(request.extensions_mut()); + let hx_trigger_name_rx = + HxTriggerNameExtracted::insert_into_extensions(request.extensions_mut()); let mut response = next.run(request).await; @@ -77,6 +118,12 @@ pub async fn vary_middleware(mut request: Request, next: Next) -> Response { if hx_target_rx.await.is_ok() { used.push(HX_TARGET_STR) } + if hx_trigger_rx.await.is_ok() { + used.push(HX_TRIGGER_STR) + } + if hx_trigger_name_rx.await.is_ok() { + used.push(HX_TRIGGER_NAME_STR) + } if !used.is_empty() { let value = match HeaderValue::from_str(&used.join(", ")) { @@ -96,41 +143,82 @@ mod tests { use axum::{routing::get, Router}; use super::*; - use crate::{HxRequest, HxTarget}; + use crate::{HxRequest, HxTarget, HxTrigger, HxTriggerName}; fn vary_headers(resp: &axum_test::TestResponse) -> Vec { resp.iter_headers_by_name("vary").cloned().collect() } - #[tokio::test] - async fn multiple_headers() { + fn server() -> axum_test::TestServer { let app = Router::new() .route("/no-extractors", get(|| async { () })) - .route("/single-extractor", get(|_: HxRequest| async { () })) - // Extractors can be used multiple times e.g. in middlewares + .route("/hx-request", get(|_: HxRequest| async { () })) + .route("/hx-target", get(|_: HxTarget| async { () })) + .route("/hx-trigger", get(|_: HxTrigger| async { () })) + .route("/hx-trigger-name", get(|_: HxTriggerName| async { () })) .route( "/repeated-extractor", get(|_: HxRequest, _: HxRequest| async { () }), ) .route( "/multiple-extractors", - get(|_: HxRequest, _: HxTarget| async { () }), + get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }), ) .layer(axum::middleware::from_fn(vary_middleware)); - let server = axum_test::TestServer::new(app).unwrap(); + axum_test::TestServer::new(app).unwrap() + } - assert!(vary_headers(&server.get("/no-extractors").await).is_empty()); + #[tokio::test] + async fn no_extractors() { + assert!(vary_headers(&server().get("/no-extractors").await).is_empty()); + } + + #[tokio::test] + async fn single_hx_request() { assert_eq!( - vary_headers(&server.get("/single-extractor").await), - [HX_REQUEST_STR] + vary_headers(&server().get("/hx-request").await), + ["hx-request"] ); + } + + #[tokio::test] + async fn single_hx_target() { assert_eq!( - vary_headers(&server.get("/repeated-extractor").await), - [HX_REQUEST_STR] + vary_headers(&server().get("/hx-target").await), + ["hx-target"] ); + } + + #[tokio::test] + async fn single_hx_trigger() { assert_eq!( - vary_headers(&server.get("/multiple-extractors").await), - [format!("{HX_REQUEST_STR}, {HX_TARGET_STR}")] + vary_headers(&server().get("/hx-trigger").await), + ["hx-trigger"] + ); + } + + #[tokio::test] + async fn single_hx_trigger_name() { + assert_eq!( + vary_headers(&server().get("/hx-trigger-name").await), + ["hx-trigger-name"] + ); + } + + #[tokio::test] + async fn repeated_extractor() { + assert_eq!( + vary_headers(&server().get("/repeated-extractor").await), + ["hx-request"] + ); + } + + // Extractors can be used multiple times e.g. in middlewares + #[tokio::test] + async fn multiple_extractors() { + assert_eq!( + vary_headers(&server().get("/multiple-extractors").await), + ["hx-request, hx-target, hx-trigger, hx-trigger-name"], ); } }