diff --git a/Cargo.toml b/Cargo.toml index d3c3655..0fd76a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,12 +15,19 @@ default = [] unstable = [] guards = ["tower", "futures-core", "pin-project-lite"] serde = ["dep:serde", "dep:serde_json"] +derive = ["dep:axum-htmx-derive"] + +[workspace] +members = ["axum-htmx-derive"] [dependencies] axum-core = "0.4" http = { version = "1.0", default-features = false } async-trait = "0.1" +# Workspace dependencies +axum-htmx-derive = { path = "axum-htmx-derive", optional = true } + # Optional dependencies required for the `guards` feature. tower = { version = "0.4", default-features = false, optional = true } futures-core = { version = "0.3", optional = true } diff --git a/axum-htmx-derive/Cargo.toml b/axum-htmx-derive/Cargo.toml new file mode 100644 index 0000000..e09e666 --- /dev/null +++ b/axum-htmx-derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "axum-htmx-derive" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro-error = "1.0" +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "1.0", features = ["extra-traits", "full", "fold"] } + +[dev-dependencies] +colored-diff = "0.2.3" diff --git a/axum-htmx-derive/README.md b/axum-htmx-derive/README.md new file mode 100644 index 0000000..fd0b70d --- /dev/null +++ b/axum-htmx-derive/README.md @@ -0,0 +1,3 @@ +# axum-htmx-derive + +This is an internal helper library of [`axum-htmx`](https://docs.rs/axum-htmx/latest/axum_htmx/). diff --git a/axum-htmx-derive/src/boosted_by/boosted_by.rs b/axum-htmx-derive/src/boosted_by/boosted_by.rs new file mode 100644 index 0000000..ff72713 --- /dev/null +++ b/axum-htmx-derive/src/boosted_by/boosted_by.rs @@ -0,0 +1,97 @@ +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::quote; +use syn::{parse2, parse_quote, parse_str, ItemFn}; + +pub struct MacroInput { + pub source_fn: ItemFn, + pub layout_fn: String, + pub fn_args: Vec, +} + +pub fn parse_macros_input( + args: TokenStream, + input: TokenStream, +) -> Result { + let mut args_iter = args.clone().into_iter().map(|arg| arg.to_string()); + + // get layout_fn from args + let layout_fn = match args_iter.next() { + Some(arg) => arg, + None => abort!( + args, + "boosted_by requires layout function (to produce non-boosted response) as an argument." + ), + }; + + // arguments for callable function + let fn_args = args_iter.collect::>(); + + // parse input as ItemFn + let source_fn = match parse2::(input) { + Ok(syntax_tree) => syntax_tree, + Err(error) => return Err(error.to_compile_error()), + }; + + Ok(MacroInput { + source_fn, + layout_fn, + fn_args, + }) +} + +pub fn transform(input: MacroInput) -> ItemFn { + let template_fn: ItemFn = parse_quote!( + fn index(axum_htmx::HxBoosted(boosted): axum_htmx::HxBoosted) { + if boosted { + result_boosted + } else { + layout_fn(result_with_layout, fn_args) + } + } + ); + + transform_using_template(input, template_fn) +} + +pub fn transform_async(input: MacroInput) -> ItemFn { + let template_fn: ItemFn = parse_quote!( + fn index(axum_htmx::HxBoosted(boosted): axum_htmx::HxBoosted) { + if boosted { + result_boosted + } else { + layout_fn(result_with_layout, fn_args).await + } + } + ); + + transform_using_template(input, template_fn) +} + +pub fn transform_using_template(input: MacroInput, template_fn: ItemFn) -> ItemFn { + let mut source_fn = input.source_fn.clone(); + + // add HxBoosted input to source_fn + let hx_boosted_input = template_fn.sig.inputs.first().unwrap().clone(); + source_fn.sig.inputs.push(hx_boosted_input); + + // pop the last statement and wrap it with if-else + let modify_stmt = source_fn.block.stmts.pop().unwrap(); + let modify_stmt = quote!(#modify_stmt).to_string(); + let modify_args = input.fn_args.join(""); + + let new_fn_str = quote!(#template_fn) + .to_string() + .replace("layout_fn", input.layout_fn.as_str()) + .replace("result_boosted", modify_stmt.as_str()) + .replace("result_with_layout", modify_stmt.as_str()) + .replace(", fn_args", modify_args.as_str()); + + let new_fn: ItemFn = parse_str(new_fn_str.as_str()).unwrap(); + let new_fn_stmt = new_fn.block.stmts.first().unwrap().clone(); + + // push the new statement to source_fn + source_fn.block.stmts.push(new_fn_stmt); + + source_fn.to_owned() +} diff --git a/axum-htmx-derive/src/boosted_by/mod.rs b/axum-htmx-derive/src/boosted_by/mod.rs new file mode 100644 index 0000000..cceade5 --- /dev/null +++ b/axum-htmx-derive/src/boosted_by/mod.rs @@ -0,0 +1,27 @@ +use proc_macro2::TokenStream; +use quote::quote; + +#[cfg(test)] +mod tests; + +mod boosted_by; + +pub fn macros(args: TokenStream, input: TokenStream) -> TokenStream { + match boosted_by::parse_macros_input(args, input) { + Ok(macros_input) => { + let new_item_fn = boosted_by::transform(macros_input); + quote!(#new_item_fn) + } + Err(error) => error, + } +} + +pub fn macros_async(args: TokenStream, input: TokenStream) -> TokenStream { + match boosted_by::parse_macros_input(args, input) { + Ok(macros_input) => { + let new_item_fn = boosted_by::transform_async(macros_input); + quote!(#new_item_fn) + } + Err(error) => error, + } +} diff --git a/axum-htmx-derive/src/boosted_by/tests.rs b/axum-htmx-derive/src/boosted_by/tests.rs new file mode 100644 index 0000000..5fc5e45 --- /dev/null +++ b/axum-htmx-derive/src/boosted_by/tests.rs @@ -0,0 +1,53 @@ +#![cfg(test)] + +use proc_macro2::TokenStream; +use quote::quote; + +use super::macros; + +#[test] +fn boosted_by() { + let before = quote! { + async fn index(Path(user_id): Path) -> Html { + let ctx = HomeTemplate { + locale: "en".to_string(), + }; + + Html(ctx.render_once().unwrap_or(String::new())) + } + }; + let expected = quote! { + async fn index(Path(user_id): Path, axum_htmx::HxBoosted(boosted): axum_htmx::HxBoosted) -> Html { + let ctx = HomeTemplate { + locale: "en".to_string(), + }; + + if boosted { + Html(ctx.render_once().unwrap_or(String::new())) + } else { + with_layout(Html(ctx.render_once().unwrap_or(String::new())), state1, state2) + } + } + }; + + let after = macros(quote! {with_layout, state1, state2}, before); + + assert_tokens_eq(&expected, &after); +} + +fn assert_tokens_eq(expected: &TokenStream, actual: &TokenStream) { + let expected = expected.to_string(); + let actual = actual.to_string(); + + if expected != actual { + println!( + "{}", + colored_diff::PrettyDifference { + expected: &expected, + actual: &actual, + } + ); + + panic!("expected != actual"); + } +} diff --git a/axum-htmx-derive/src/lib.rs b/axum-htmx-derive/src/lib.rs new file mode 100644 index 0000000..5efdc30 --- /dev/null +++ b/axum-htmx-derive/src/lib.rs @@ -0,0 +1,17 @@ +#![doc = include_str!("../README.md")] +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; + +mod boosted_by; + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn hx_boosted_by(args: TokenStream, input: TokenStream) -> TokenStream { + boosted_by::macros(args.into(), input.into()).into() +} + +#[proc_macro_error] +#[proc_macro_attribute] +pub fn hx_boosted_by_async(args: TokenStream, input: TokenStream) -> TokenStream { + boosted_by::macros_async(args.into(), input.into()).into() +} diff --git a/src/lib.rs b/src/lib.rs index 439ff68..a5e4af6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,3 +22,7 @@ pub use guard::*; pub use headers::*; #[doc(inline)] pub use responders::*; + +#[cfg(feature = "derive")] +#[cfg_attr(feature = "unstable", doc(cfg(feature = "derive")))] +pub use axum_htmx_derive::*;