diff --git a/Cargo.toml b/Cargo.toml index 75c352d..928e23d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,13 @@ serde = "1.0" futures-util = "0.3" reqwest = { version = "0.11", default-features = false } urlencoding = "2.1" + +[dev-dependencies] +axum-test = "15.7.1" +testcontainers = "0.22.0" +reqwest = { version = "0.11", default-features = false, features = ["cookies"] } +tower = { version = "0.5.1", features = ["util"] } +tokio = { version = "1.40.0", features = ["full"] } +regex = "1.10.6" +serde_json = "1.0.128" +tower-sessions = { version = "0.12", default-features = false, features = [ "memory-store", "axum-core" ] } diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..5b42095 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,123 @@ + +use axum::{error_handling::HandleErrorLayer, routing::get}; +use axum_oidc::EmptyAdditionalClaims; +use keycloak::Keycloak; +use utils::handle_axum_oidc_middleware_error; + +mod keycloak; +mod utils; + +#[tokio::test(flavor = "multi_thread")] +async fn basic_login_oidc() { + let john = keycloak::User { + username: "jojo".to_string(), + email: "john.doe@example.com".to_string(), + firstname: "john".to_string(), + lastname: "doe".to_string(), + password: "jopass".to_string(), + }; + + let basic_client = keycloak::Client { + client_id: "axum-oidc-example-basic".to_string(), + client_secret: Some("123456".to_string()), + ..Default::default() + }; + + let realm_name = "test"; + + let keycloak = Keycloak::start(vec![keycloak::Realm { + name: realm_name.to_string(), + clients: vec![basic_client.clone()], + users: vec![], // Not used here, needed for id + }]) + .await + .unwrap(); + let id = keycloak.create_user(&john.username, &john.email, &john.firstname, &john.lastname, &john.password, realm_name).await; + + let keycloak_url = keycloak.url(); + let issuer = format!("{keycloak_url}/realms/{realm_name}"); + + let login_service = tower::ServiceBuilder::new() + .layer(HandleErrorLayer::new(handle_axum_oidc_middleware_error)) + .layer(axum_oidc::OidcLoginLayer::::new()); + + let oidc_client = axum_oidc::OidcAuthLayer::::discover_client( + axum::http::Uri::from_static("http://localhost:3000"), + issuer, + basic_client.client_id, + basic_client.client_secret, + vec![] + ) + .await + .expect("Cannot create OIDC client"); + + let auth_service = tower::ServiceBuilder::new() + .layer(HandleErrorLayer::new(handle_axum_oidc_middleware_error)) + .layer(oidc_client); + + let session_store = tower_sessions::MemoryStore::default(); + let session_layer = tower_sessions::SessionManagerLayer::new(session_store) + .with_same_site(tower_sessions::cookie::SameSite::None) + .with_expiry(tower_sessions::Expiry::OnInactivity( + tower_sessions::cookie::time::Duration::minutes(120), + )); + + let app = axum::Router::new() + .route("/foo", get(utils::authenticated)) + .layer(login_service) + .route("/bar", get(utils::maybe_authenticated)) + .layer(auth_service) + .layer(session_layer); + + + let server = axum_test::TestServerConfig::builder() + .save_cookies() + .http_transport() + .build_server(app) + .unwrap(); + + let client = reqwest::ClientBuilder::new() + .cookie_store(true) + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + // GET /bar + let response = server.get("/bar").await; + response.assert_status(axum_test::http::StatusCode::OK); + response.assert_text("Hello anon!"); + + // GET /foo + let response = server.get("/foo").await; + response.assert_status(axum_test::http::StatusCode::TEMPORARY_REDIRECT); + let url = utils::extract_location_header_testresponse(response).unwrap(); + + // GET keycloak/auth + let response = client.get(url).send().await.unwrap(); + assert_eq!(response.status(), reqwest::StatusCode::OK); + let html = response.text().await.unwrap(); + let url_regex = regex::Regex::new(r#"action="([^"]+)""#).unwrap(); + let url = url_regex.captures(&html).unwrap().get(1).unwrap().as_str(); + let params = [("username", "jojo"), ("password", "jopass")]; + + // POST keycloak/auth + let response = client.post(url).form(¶ms).send().await.unwrap(); + assert_eq!(response.status(), reqwest::StatusCode::FOUND); + let url = utils::extract_location_header_response(response).unwrap(); + let url = url.replace("http://localhost:3000", ""); // Remove http://localhost:3000 + + // GET /foo-callback + let response = server.get(&url).await; + response.assert_status(axum_test::http::StatusCode::TEMPORARY_REDIRECT); + response.assert_header("Location", "http://localhost:3000/foo"); + + // GET /foo + let response = server.get("/foo").await; + response.assert_status(axum_test::http::StatusCode::OK); + response.assert_text(format!("Hello {id}")); + + // GET /bar + let response = server.get("/bar").await; + response.assert_status(axum_test::http::StatusCode::OK); + response.assert_text(format!("Hello {id}! You are already logged in from another Handler.")); +} diff --git a/tests/keycloak.rs b/tests/keycloak.rs new file mode 100644 index 0000000..6490748 --- /dev/null +++ b/tests/keycloak.rs @@ -0,0 +1,272 @@ +use testcontainers::{ + core::{CmdWaitFor, ExecCommand, Image, WaitFor}, + runners::AsyncRunner, + ContainerAsync, +}; + +#[derive(Debug, Default, Clone)] +struct KeycloakImage; + +const NAME: &str = "quay.io/keycloak/keycloak"; +const TAG: &str = "25.0"; + +impl Image for KeycloakImage { + fn name(&self) -> &str { + NAME + } + + fn tag(&self) -> &str { + TAG + } + + fn ready_conditions(&self) -> Vec { + vec![WaitFor::message_on_stdout("Listening on:"), + WaitFor::message_on_stdout("Running the server in development mode. DO NOT use this configuration in production.") + ] + } + + fn env_vars( + &self, + ) -> impl IntoIterator< + Item = ( + impl Into>, + impl Into>, + ), + > { + [ + ("KEYCLOAK_ADMIN", "admin"), + ("KEYCLOAK_ADMIN_PASSWORD", "admin"), + ] + } + + fn cmd(&self) -> impl IntoIterator>> { + ["start-dev"] + } +} + +pub struct Keycloak { + container: ContainerAsync, + realms: Vec, + url: String, +} + +#[derive(Clone)] +pub struct Realm { + pub name: String, + pub clients: Vec, + pub users: Vec, +} + +#[derive(Clone)] +pub struct Client { + pub client_id: String, + pub client_secret: Option, +} + +impl Default for Client { + fn default() -> Self { + Self { + client_id: "0".to_owned(), + client_secret: None, + } + } +} + +#[derive(Clone)] +pub struct User { + pub username: String, + pub email: String, + pub firstname: String, + pub lastname: String, + pub password: String, +} + +impl Keycloak { + pub async fn start( + realms: Vec, + ) -> Result> { + let container = KeycloakImage.start().await?; + + let keycloak = Self { + url: format!( + "http://localhost:{}", + container.get_host_port_ipv4(8080).await?, + ), + container, + realms, + }; + + keycloak + .container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "config", + "credentials", + "--server", + "http://localhost:8080", + "--realm", + "master", + "--user", + "admin", + "--password", + "admin", + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap(); + + for realm in keycloak.realms.iter() { + if realm.name != "master" { + keycloak.create_realm(&realm.name).await; + } + for client in realm.clients.iter() { + keycloak + .create_client( + &client.client_id, + client.client_secret.as_deref(), + &realm.name, + ) + .await; + } + for user in realm.users.iter() { + keycloak + .create_user( + &user.username, + &user.email, + &user.firstname, + &user.lastname, + &user.password, + &realm.name, + ) + .await; + } + } + + Ok(keycloak) + } + + pub fn url(&self) -> &str { + &self.url + } + + pub async fn create_realm(&self, name: &str) { + self.container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "create", + "realms", + "-s", + &format!("realm={name}"), + "-s", + "enabled=true", + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap(); + } + + pub async fn create_client(&self, client_id: &str, client_secret: Option<&str>, realm: &str) { + if let Some(client_secret) = client_secret { + self.container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "create", + "clients", + "-r", + &realm, + "-s", + &format!("clientId={client_id}"), + "-s", + &format!("secret={client_secret}"), + "-s", + "redirectUris=[\"*\"]", + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap(); + } else { + self.container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "create", + "clients", + "-r", + &realm, + "-s", + &format!("clientId={client_id}"), + "-s", + "redirectUris=[\"*\"]", + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap(); + } + } + + pub async fn create_user( + &self, + username: &str, + email: &str, + firstname: &str, + lastname: &str, + password: &str, + realm: &str, + ) -> String { + let stderr = self.container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "create", + "users", + "-r", + &realm, + "-s", + &format!("username={username}"), + "-s", + "enabled=true", + "-s", + "emailVerified=true", + "-s", + &format!("email={email}"), + "-s", + &format!("firstName={firstname}"), + "-s", + &format!("lastName={lastname}"), + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap() + .stderr_to_vec() + .await.unwrap(); + + let stderr = String::from_utf8_lossy(&stderr); + let id = stderr.split('\'').nth(1).unwrap().to_string(); + + self.container + .exec( + ExecCommand::new([ + "/opt/keycloak/bin/kcadm.sh", + "set-password", + "-r", + &realm, + "--username", + username, + "--new-password", + password, + ]) + .with_cmd_ready_condition(CmdWaitFor::exit_code(0)), + ) + .await + .unwrap(); + id + } +} diff --git a/tests/utils.rs b/tests/utils.rs new file mode 100644 index 0000000..75ae578 --- /dev/null +++ b/tests/utils.rs @@ -0,0 +1,48 @@ +use axum::response::IntoResponse; +use axum_oidc::{EmptyAdditionalClaims, OidcClaims}; + + +pub async fn authenticated(claims: OidcClaims) -> impl IntoResponse { + format!("Hello {}", claims.subject().as_str()) +} + +pub async fn maybe_authenticated( + claims: Option>, +) -> impl IntoResponse { + if let Some(claims) = claims { + format!( + "Hello {}! You are already logged in from another Handler.", + claims.subject().as_str() + ) + } else { + "Hello anon!".to_string() + } +} + +pub async fn handle_axum_oidc_middleware_error( + e: axum_oidc::error::MiddlewareError, +) -> axum::http::Response { + e.into_response() +} + +pub fn extract_location_header_testresponse(response: axum_test::TestResponse) -> Option { + Some( + response + .headers() + .get("Location")? + .to_str() + .ok()? + .to_string(), + ) +} + +pub fn extract_location_header_response(response: reqwest::Response) -> Option { + Some( + response + .headers() + .get("Location")? + .to_str() + .ok()? + .to_string(), + ) +}