#![deny(clippy::unwrap_used)] use std::{ collections::{BinaryHeap, HashMap}, env, ops::Deref, sync::Arc, }; use axum::{ error_handling::HandleErrorLayer, extract::{Multipart, Path, Query, State}, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode, Uri}, response::{ sse::{Event, KeepAlive}, Html, IntoResponse, Redirect, Sse, }, routing::get, BoxError, Form, Router, }; use axum_htmx::{HxRedirect, HxRequest}; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, }; use futures_util::Stream; use game::{Game, GameId, GameParticipantsCollector, PlayerId, RunningGamesCollector}; use garbage_collector::{start_gc, GarbageCollectorItem}; use prometheus_client::{metrics::counter::Counter, registry::Registry}; use question::{single_choice::SingleChoiceQuestion, Question}; use sailfish::TemplateOnce; use serde::{Deserialize, Serialize}; use stream::{PlayerBroadcastStream, ViewerBroadcastStream}; use tokio::{sync::RwLock, task::spawn_blocking}; use tower::ServiceBuilder; use tower_http::services::ServeDir; use tower_sessions::{cookie::SameSite, MemoryStore, SessionManagerLayer}; use crate::error::Error; type HandlerResult = Result; mod error; mod game; mod garbage_collector; mod stream; mod question; #[derive(Clone)] pub struct AppState { games: Arc>>, game_expiry: Arc>>, application_base: &'static str, prometheus_registry: Arc, metrics: Arc, } #[derive(Clone, Default)] pub struct AppMetrics { arc_games_total: Counter, } #[tokio::main] pub async fn main() { dotenvy::dotenv().ok(); env_logger::init(); let application_base = env::var("APPLICATION_BASE").expect("APPLICATION_BASE env var"); let issuer = env::var("ISSUER").expect("ISSUER env var"); let client_id = env::var("CLIENT_ID").expect("CLIENT_ID env var"); let client_secret = env::var("CLIENT_SECRET").ok(); let scopes = env::var("SCOPES") .expect("SCOPES env var") .split(' ') .map(|x| x.to_owned()) .collect::>(); let session_store = MemoryStore::default(); let session_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|_: BoxError| async { StatusCode::BAD_REQUEST })) .layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax)); let oidc_login_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { e.into_response() })) .layer(OidcLoginLayer::::new()); let oidc_auth_service = ServiceBuilder::new() .layer(HandleErrorLayer::new(|e: MiddlewareError| async { e.into_response() })) .layer( OidcAuthLayer::::discover_client( Uri::from_maybe_shared(application_base.clone()).expect("valid APPLICATION_BASE"), issuer.to_string(), client_id.to_string(), client_secret.to_owned(), scopes.clone(), ) .await .expect("OIDC Client"), ); let game_expiry: Arc>> = Arc::new(RwLock::new(BinaryHeap::new())); let games = Arc::new(RwLock::new(HashMap::new())); let app_metrics = Arc::new(AppMetrics::default()); let mut registry = Registry::default(); registry.register( "ars_games_total", "number of games created", app_metrics.arc_games_total.clone(), ); registry.register_collector(Box::new(RunningGamesCollector::new(games.clone()))); registry.register_collector(Box::new(GameParticipantsCollector::new(games.clone()))); start_gc(game_expiry.clone(), games.clone()); let app_state = AppState { games, game_expiry, application_base: Box::leak(application_base.into()), prometheus_registry: Arc::new(registry), metrics: app_metrics, }; let app = Router::new() .route("/", get(handle_index).post(handle_create)) .route("/:id/view", get(handle_view).post(handle_view_next)) .route("/:id/view/events", get(sse_view)) .layer(oidc_login_service) .route("/:id", get(handle_player).post(handle_player_answer)) .route("/:id/events", get(sse_player)) .route("/metrics", get(metrics)) .nest_service("/static", ServeDir::new("static")) .with_state(app_state) .layer(oidc_auth_service) .layer(session_service); axum::Server::bind(&"[::]:8080".parse().expect("valid listen address")) .serve(app.into_make_service()) .await .expect("axum server"); } pub async fn handle_index() -> HandlerResult { Ok(Html(IndexTemplate {}.render_once()?)) } pub async fn handle_create( State(state): State, OidcClaims(claims): OidcClaims, mut body: Multipart, ) -> HandlerResult { let mut quiz: Option = None; while let Some(field) = body.next_field().await? { if field.name() == Some("quizfile") { quiz = Some(toml::from_str::(&field.text().await?)?); } } let quiz = quiz.ok_or(Error::QuizFileNotFound)?; let game_id = GameId::random(); let game = Game::new(game_id.clone(), claims.subject().to_string(), quiz); let mut games = state.games.write().await; games.insert(game_id.clone(), game); let url = format!("{}/{}/view", state.application_base, &game_id.deref()); let mut game_expiry = state.game_expiry.write().await; game_expiry.push(GarbageCollectorItem::new_in(game_id, 24 * 3600)); state.metrics.arc_games_total.inc(); Ok((HxRedirect(Uri::from_maybe_shared(url.clone())?), "Ok")) } pub async fn handle_view( Path(id): Path, State(state): State, HxRequest(htmx): HxRequest, OidcClaims(claims): OidcClaims, ) -> HandlerResult { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } Ok(Html(game.viewer_view(htmx, state.application_base).await?)) } pub async fn handle_view_next( Path(id): Path, State(state): State, OidcClaims(claims): OidcClaims, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } game.next().await; Ok("Ok".into_response()) } pub async fn sse_view( Path(id): Path, State(state): State, OidcClaims(claims): OidcClaims, ) -> HandlerResult>>> { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; if game.owner != claims.subject().to_string() { return Err(Error::Forbidden); } let rx1 = game.on_state_update.subscribe(); let rx2 = game.on_submission.subscribe(); let stream = ViewerBroadcastStream::new(rx1, rx2, state.games.clone(), id, state.application_base); Ok(Sse::new(stream).keep_alive(KeepAlive::default())) } #[derive(Deserialize)] pub struct PlayerQuery { player: Option, } pub async fn handle_player( Query(query): Query, Path(id): Path, State(state): State, HxRequest(htmx): HxRequest, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; if let Some(player_id) = query.player.map(PlayerId::from) { Ok(Html(game.player_view(&player_id, htmx).await?).into_response()) } else { let player_id = PlayerId::random(); game.players.insert(player_id.clone()); game.on_submission.send(()); Ok(Redirect::temporary(&format!( "{}/{}?player={}", state.application_base, id.deref(), player_id.deref() )) .into_response()) } } #[derive(Deserialize)] pub struct SubmissionPayload { player_id: PlayerId, #[serde(flatten)] values: HashMap, } pub async fn handle_player_answer( Path(id): Path, State(state): State, Form(form): Form, ) -> HandlerResult { let mut games = state.games.write().await; let game = games.get_mut(&id).ok_or(Error::NotFound)?; game.handle_answer(&form.player_id, &form.values).await?; Ok(Html(game.player_view(&form.player_id, true).await?)) } #[derive(Deserialize)] pub struct SsePlayerQuery { player: PlayerId, } pub async fn sse_player( Query(query): Query, Path(id): Path, State(state): State, ) -> HandlerResult>>> { let games = state.games.read().await; let game = games.get(&id).ok_or(Error::NotFound)?; let rx = game.on_state_update.subscribe(); let stream = PlayerBroadcastStream::new(rx, state.games.clone(), id, query.player); Ok(Sse::new(stream).keep_alive(KeepAlive::default())) } async fn metrics(State(app_state): State) -> HandlerResult { let registry = app_state.prometheus_registry.clone(); let buffer = spawn_blocking::<_, Result>(move || { let mut buffer = String::new(); prometheus_client::encoding::text::encode(&mut buffer, ®istry) .map_err(Error::Prometheus)?; Ok(buffer) }) .await??; let mut headers = HeaderMap::new(); headers.insert( CONTENT_TYPE, HeaderValue::from_static("text/plain; version=0.0.4"), ); Ok((headers, buffer)) } #[derive(TemplateOnce)] #[template(path = "index.stpl")] struct IndexTemplate {} #[derive(TemplateOnce)] #[template(path = "play.stpl")] struct PlayTemplate<'a> { htmx: bool, id: &'a str, player_id: &'a str, state: PlayerState, } pub enum PlayerState { NotStarted, Answering { inner_body: String }, Waiting(u32), Result { inner_body: String }, Completed(f32), } #[derive(TemplateOnce)] #[template(path = "view.stpl")] struct ViewTemplate<'a> { htmx: bool, id: &'a str, state: ViewerState, } pub enum ViewerState { NotStarted((u32, String, String)), Answering { inner_body: String, }, Result { last_question: bool, inner_body: String, }, Completed, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Quiz { pub wait_for: u64, pub questions: Vec, } #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "type")] pub enum QuizQuestion { #[serde(rename = "single_choice")] SingleChoice(SingleChoice), } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct SingleChoice { name: Box, answers: Box<[Box]>, correct: u32, } impl From for Box { fn from(value: QuizQuestion) -> Self { match value { QuizQuestion::SingleChoice(x) => Box::new(SingleChoiceQuestion::new(x)) as _, } } }