use axum-oidc

This commit is contained in:
Paul Zinselmeyer 2023-11-03 19:48:48 +01:00
parent 70b0da40dc
commit ee2c7355f6
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
3 changed files with 152 additions and 250 deletions

View file

@ -8,17 +8,22 @@ use std::{
use axum::{
async_trait,
body::HttpBody,
error_handling::HandleErrorLayer,
extract::{FromRef, Multipart, Path, Query, State},
http::Uri,
http::{Request, StatusCode, Uri},
response::{
sse::{Event, KeepAlive},
Html, IntoResponse, Redirect, Sse,
},
routing::get,
Form, Router,
BoxError, Form, Router,
};
use axum_htmx::{HxRedirect, HxRequest};
use axum_oidc::oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor};
use axum_oidc::{
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient,
OidcLoginLayer,
};
use futures_util::Stream;
use game::Game;
use garbage_collector::{start_gc, GarbageCollectorItem};
@ -28,7 +33,9 @@ use sailfish::TemplateOnce;
use serde::{Deserialize, Serialize};
use stream::{PlayerBroadcastStream, ViewerBroadcastStream};
use tokio::sync::RwLock;
use tower::{Layer, ServiceBuilder};
use tower_http::services::ServeDir;
use tower_sessions::{cookie::SameSite, Expiry, MemoryStore, SessionManagerLayer};
use crate::error::Error;
@ -45,16 +52,9 @@ mod question;
pub struct AppState {
games: Arc<RwLock<HashMap<String, Game>>>,
game_expiry: Arc<RwLock<BinaryHeap<GarbageCollectorItem>>>,
oidc_application: OidcApplication<EmptyAdditionalClaims>,
application_base: String,
}
impl FromRef<AppState> for OidcApplication<EmptyAdditionalClaims> {
fn from_ref(input: &AppState) -> Self {
input.oidc_application.clone()
}
}
#[tokio::main]
pub async fn main() {
dotenvy::dotenv().ok();
@ -70,18 +70,34 @@ pub async fn main() {
.map(|x| x.to_owned())
.collect::<Vec<_>>();
let oidc_application = OidcApplication::<EmptyAdditionalClaims>::create(
application_base
.parse()
.expect("valid APPLICATION_BASE url"),
issuer.to_string(),
client_id.to_string(),
client_secret.to_owned(),
scopes.clone(),
oidc::Key::generate(),
)
.await
.expect("Oidc Authentication Client");
let session_store = MemoryStore::default();
let session_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::BAD_REQUEST
}))
.layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax));
let oidc_login_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
let oidc_auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
e.into_response()
}))
.layer(
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
Uri::from_maybe_shared(application_base.clone()).expect("valid APPLICATION_BASE"),
issuer.to_string(),
client_id.to_string(),
client_secret.to_owned(),
scopes.clone(),
)
.await
.expect("OIDC Client"),
);
let game_expiry: Arc<RwLock<BinaryHeap<GarbageCollectorItem>>> =
Arc::new(RwLock::new(BinaryHeap::new()));
@ -92,18 +108,20 @@ pub async fn main() {
let app_state = AppState {
games,
game_expiry,
oidc_application,
application_base,
};
let app = Router::new()
.route("/", get(handle_index).post(handle_create))
.route("/:id", get(handle_player).post(handle_player_answer))
.route("/:id/events", get(sse_player))
.route("/:id/view", get(handle_view).post(handle_view_next))
.route("/:id/view/events", get(sse_view))
.layer(oidc_login_service)
.route("/:id", get(handle_player).post(handle_player_answer))
.route("/:id/events", get(sse_player))
.nest_service("/static", ServeDir::new("static"))
.with_state(app_state);
.with_state(app_state)
.layer(oidc_auth_service)
.layer(session_service);
axum::Server::bind(&"[::]:8080".parse().expect("valid listen address"))
.serve(app.into_make_service())
@ -111,15 +129,13 @@ pub async fn main() {
.expect("axum server");
}
pub async fn handle_index(
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
) -> HandlerResult<impl IntoResponse> {
pub async fn handle_index() -> HandlerResult<impl IntoResponse> {
Ok(Html(IndexTemplate {}.render_once()?))
}
pub async fn handle_create(
State(state): State<AppState>,
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
mut body: Multipart,
) -> HandlerResult<impl IntoResponse> {
let mut quiz: Option<Quiz> = None;
@ -137,11 +153,7 @@ pub async fn handle_create(
.map(char::from)
.collect();
let game = Game::new(
game_id.clone(),
oidc_extractor.claims.subject().to_string(),
quiz,
);
let game = Game::new(game_id.clone(), claims.subject().to_string(), quiz);
let mut games = state.games.write().await;
@ -159,12 +171,12 @@ pub async fn handle_view(
Path(id): Path<String>,
State(state): State<AppState>,
HxRequest(htmx): HxRequest,
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
) -> HandlerResult<impl IntoResponse> {
let games = state.games.read().await;
let game = games.get(&id).ok_or(Error::NotFound)?;
if game.owner != oidc_extractor.claims.subject().to_string() {
if game.owner != claims.subject().to_string() {
return Err(Error::Forbidden);
}
@ -175,12 +187,12 @@ pub async fn handle_view_next(
Path(id): Path<String>,
State(state): State<AppState>,
HxRequest(htmx): HxRequest,
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
) -> HandlerResult<impl IntoResponse> {
let mut games = state.games.write().await;
let game = games.get_mut(&id).ok_or(Error::NotFound)?;
if game.owner != oidc_extractor.claims.subject().to_string() {
if game.owner != claims.subject().to_string() {
return Err(Error::Forbidden);
}
@ -192,12 +204,12 @@ pub async fn handle_view_next(
pub async fn sse_view(
Path(id): Path<String>,
State(state): State<AppState>,
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
) -> HandlerResult<Sse<impl Stream<Item = Result<Event, Error>>>> {
let games = state.games.read().await;
let game = games.get(&id).ok_or(Error::NotFound)?;
if game.owner != oidc_extractor.claims.subject().to_string() {
if game.owner != claims.subject().to_string() {
return Err(Error::Forbidden);
}