#![deny(clippy::unwrap_used)] use std::{ collections::{BinaryHeap, HashMap}, env, sync::Arc, }; use axum::{ async_trait, body::HttpBody, error_handling::HandleErrorLayer, extract::{FromRef, Multipart, Path, Query, State}, http::{Request, StatusCode, Uri}, response::{ sse::{Event, KeepAlive}, Html, IntoResponse, Redirect, Sse, }, routing::get, BoxError, Form, Router, }; use axum_htmx::{HxRedirect, HxRequest}; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient, OidcLoginLayer, }; use futures_util::Stream; use game::Game; use garbage_collector::{start_gc, GarbageCollectorItem}; use question::single_choice::SingleChoiceQuestion; use rand::{distributions, Rng}; use sailfish::TemplateOnce; use serde::{Deserialize, Serialize}; use stream::{PlayerBroadcastStream, ViewerBroadcastStream}; use tokio::sync::RwLock; use tower::{Layer, ServiceBuilder}; use tower_http::services::ServeDir; use tower_sessions::{cookie::SameSite, Expiry, MemoryStore, SessionManagerLayer}; use crate::error::Error; type HandlerResult = Result; mod error; mod game; mod garbage_collector; mod stream; mod question; #[derive(Clone)] pub struct AppState { games: Arc>>, game_expiry: Arc>>, application_base: String, } #[tokio::main] pub async fn main() { dotenvy::dotenv().ok(); env_logger::init(); let application_base = env::var("APPLICATION_BASE").expect("APPLICATION_BASE env var"); let issuer = env::var("ISSUER").expect("ISSUER env var"); let client_id = env::var("CLIENT_ID").expect("CLIENT_ID env var"); let client_secret = env::var("CLIENT_SECRET").ok(); let scopes = env::var("SCOPES") .expect("SCOPES env var") .split(' ') .map(|x| x.to_owned()) .collect::>(); let session_store = MemoryStore::default(); let session_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::BAD_REQUEST })) .layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax)); let oidc_login_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { e.into_response() })) .layer(OidcLoginLayer::::new()); let oidc_auth_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { e.into_response() })) .layer( OidcAuthLayer::::discover_client( Uri::from_maybe_shared(application_base.clone()).expect("valid APPLICATION_BASE"), issuer.to_string(), client_id.to_string(), client_secret.to_owned(), scopes.clone(), ) .await .expect("OIDC Client"), ); let game_expiry: Arc>> = Arc::new(RwLock::new(BinaryHeap::new())); let games = Arc::new(RwLock::new(HashMap::new())); start_gc(game_expiry.clone(), games.clone()); let app_state = AppState { games, game_expiry, application_base, }; let app = Router::new() .route("/", get(handle_index).post(handle_create)) .route("/:id/view", get(handle_view).post(handle_view_next)) .route("/:id/view/events", get(sse_view)) .layer(oidc_login_service) .route("/:id", get(handle_player).post(handle_player_answer)) .route("/:id/events", get(sse_player)) .nest_service("/static", ServeDir::new("static")) .with_state(app_state) .layer(oidc_auth_service) .layer(session_service); axum::Server::bind(&"[::]:8080".parse().expect("valid listen address")) .serve(app.into_make_service()) .await .expect("axum server"); } pub async fn handle_index() -> HandlerResult { Ok(Html(IndexTemplate {}.render_once()?)) } pub async fn handle_create( State(state): State, OidcClaims(claims): OidcClaims, mut body: Multipart, ) -> HandlerResult { let mut quiz: Option = None; while let Some(field) = body.next_field().await? { if field.name() == Some("quizfile") { quiz = Some(toml::from_str::(&field.text().await?)?); } } let quiz = quiz.ok_or(Error::QuizFileNotFound)?; let game_id: String = rand::thread_rng() .sample_iter(distributions::Alphanumeric) .take(8) .map(char::from) .collect(); let game = Game::new(game_id.clone(), claims.subject().to_string(), quiz); let mut games = state.games.write().await; games.insert(game_id.clone(), game); let url = format!("{}/{}/view", state.application_base, &game_id); let mut game_expiry = state.game_expiry.write().await; game_expiry.push(GarbageCollectorItem::new_in(game_id, 24 * 3600)); Ok((HxRedirect(Uri::from_maybe_shared(url.clone())?), "Ok")) } pub async fn handle_view( Path(id): Path, State(state): State, HxRequest(htmx): HxRequest, OidcClaims(claims): OidcClaims, ) -> HandlerResult { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } Ok(Html(game.viewer_view(htmx, &state.application_base).await?)) } pub async fn handle_view_next( Path(id): Path, State(state): State, HxRequest(htmx): HxRequest, OidcClaims(claims): OidcClaims, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } game.next().await; Ok("Ok".into_response()) } pub async fn sse_view( Path(id): Path, State(state): State, OidcClaims(claims): OidcClaims, ) -> HandlerResult>>> { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } let rx1 = game.on_state_update.subscribe(); let rx2 = game.on_submission.subscribe(); let stream = ViewerBroadcastStream::new( rx1, rx2, state.games.clone(), id, state.application_base.clone(), ); Ok(Sse::new(stream).keep_alive(KeepAlive::default())) } #[derive(Deserialize)] pub struct PlayerQuery { player: Option, } pub async fn handle_player( Query(query): Query, Path(id): Path, State(state): State, HxRequest(htmx): HxRequest, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; if let Some(player_id) = query.player { Ok(Html(game.player_view(&player_id, htmx).await?).into_response()) } else { let player_id: String = rand::thread_rng() .sample_iter(distributions::Alphanumeric) .take(32) .map(char::from) .collect(); game.players.insert(player_id.to_string()); game.on_submission.send(()); Ok(Redirect::temporary(&format!( "{}/{}?player={}", state.application_base, id, player_id )) .into_response()) } } #[derive(Deserialize)] pub struct SubmissionPayload { player_id: String, #[serde(flatten)] values: HashMap, } pub async fn handle_player_answer( Path(id): Path, State(state): State, Form(form): Form, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; game.handle_answer(&form.player_id, &form.values).await?; Ok(Html(game.player_view(&form.player_id, true).await?)) } #[derive(Deserialize)] pub struct SsePlayerQuery { player: String, } pub async fn sse_player( Query(query): Query, Path(id): Path, State(state): State, ) -> HandlerResult>>> { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; let rx = game.on_state_update.subscribe(); let stream = PlayerBroadcastStream::new(rx, state.games.clone(), id, query.player); Ok(Sse::new(stream).keep_alive(KeepAlive::default())) } #[derive(TemplateOnce)] #[template(path = "index.stpl")] struct IndexTemplate {} #[derive(TemplateOnce)] #[template(path = "play.stpl")] struct PlayTemplate<'a> { htmx: bool, id: &'a str, player_id: &'a str, state: PlayerState, } #[derive(Clone)] pub enum PlayerState { NotStarted, Answering { inner_body: String }, Waiting(u32), Result { inner_body: String }, Completed(f32), } #[derive(TemplateOnce)] #[template(path = "view.stpl")] struct ViewTemplate<'a> { htmx: bool, id: &'a str, state: ViewerState, } #[derive(Clone)] pub enum ViewerState { NotStarted((u32, String, String)), Answering { inner_body: String, }, Result { last_question: bool, inner_body: String, }, Completed, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Quiz { pub wait_for: u64, pub questions: Vec, } #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "type")] pub enum QuizQuestion { #[serde(rename = "single_choice")] SingleChoice(SingleChoice), } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct SingleChoice { name: String, answers: Vec, correct: u32, } #[async_trait] pub trait Question: Send + Sync { async fn render_player(&self, player_id: &str, show_result: bool) -> Result; async fn handle_answer( &mut self, player_id: &str, values: &HashMap, ) -> Result<(), Error>; async fn has_answered(&self, player_id: &str) -> Result; async fn answered_correctly(&self, player_id: &str) -> Result; async fn answer_count(&self) -> Result; async fn render_viewer(&self, show_result: bool) -> Result; } impl From for Box { fn from(value: QuizQuestion) -> Self { match value { QuizQuestion::SingleChoice(x) => Box::new(SingleChoiceQuestion::new(x)) as _, } } }