use axum-oidc
This commit is contained in:
parent
70b0da40dc
commit
ee2c7355f6
3 changed files with 152 additions and 250 deletions
94
src/main.rs
94
src/main.rs
|
|
@ -8,17 +8,22 @@ use std::{
|
|||
|
||||
use axum::{
|
||||
async_trait,
|
||||
body::HttpBody,
|
||||
error_handling::HandleErrorLayer,
|
||||
extract::{FromRef, Multipart, Path, Query, State},
|
||||
http::Uri,
|
||||
http::{Request, StatusCode, Uri},
|
||||
response::{
|
||||
sse::{Event, KeepAlive},
|
||||
Html, IntoResponse, Redirect, Sse,
|
||||
},
|
||||
routing::get,
|
||||
Form, Router,
|
||||
BoxError, Form, Router,
|
||||
};
|
||||
use axum_htmx::{HxRedirect, HxRequest};
|
||||
use axum_oidc::oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor};
|
||||
use axum_oidc::{
|
||||
error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcClient,
|
||||
OidcLoginLayer,
|
||||
};
|
||||
use futures_util::Stream;
|
||||
use game::Game;
|
||||
use garbage_collector::{start_gc, GarbageCollectorItem};
|
||||
|
|
@ -28,7 +33,9 @@ use sailfish::TemplateOnce;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use stream::{PlayerBroadcastStream, ViewerBroadcastStream};
|
||||
use tokio::sync::RwLock;
|
||||
use tower::{Layer, ServiceBuilder};
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_sessions::{cookie::SameSite, Expiry, MemoryStore, SessionManagerLayer};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
|
|
@ -45,16 +52,9 @@ mod question;
|
|||
pub struct AppState {
|
||||
games: Arc<RwLock<HashMap<String, Game>>>,
|
||||
game_expiry: Arc<RwLock<BinaryHeap<GarbageCollectorItem>>>,
|
||||
oidc_application: OidcApplication<EmptyAdditionalClaims>,
|
||||
application_base: String,
|
||||
}
|
||||
|
||||
impl FromRef<AppState> for OidcApplication<EmptyAdditionalClaims> {
|
||||
fn from_ref(input: &AppState) -> Self {
|
||||
input.oidc_application.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn main() {
|
||||
dotenvy::dotenv().ok();
|
||||
|
|
@ -70,18 +70,34 @@ pub async fn main() {
|
|||
.map(|x| x.to_owned())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let oidc_application = OidcApplication::<EmptyAdditionalClaims>::create(
|
||||
application_base
|
||||
.parse()
|
||||
.expect("valid APPLICATION_BASE url"),
|
||||
issuer.to_string(),
|
||||
client_id.to_string(),
|
||||
client_secret.to_owned(),
|
||||
scopes.clone(),
|
||||
oidc::Key::generate(),
|
||||
)
|
||||
.await
|
||||
.expect("Oidc Authentication Client");
|
||||
let session_store = MemoryStore::default();
|
||||
let session_service = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|_: BoxError| async {
|
||||
StatusCode::BAD_REQUEST
|
||||
}))
|
||||
.layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax));
|
||||
|
||||
let oidc_login_service = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
|
||||
e.into_response()
|
||||
}))
|
||||
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
|
||||
|
||||
let oidc_auth_service = ServiceBuilder::new()
|
||||
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
|
||||
e.into_response()
|
||||
}))
|
||||
.layer(
|
||||
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
|
||||
Uri::from_maybe_shared(application_base.clone()).expect("valid APPLICATION_BASE"),
|
||||
issuer.to_string(),
|
||||
client_id.to_string(),
|
||||
client_secret.to_owned(),
|
||||
scopes.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("OIDC Client"),
|
||||
);
|
||||
|
||||
let game_expiry: Arc<RwLock<BinaryHeap<GarbageCollectorItem>>> =
|
||||
Arc::new(RwLock::new(BinaryHeap::new()));
|
||||
|
|
@ -92,18 +108,20 @@ pub async fn main() {
|
|||
let app_state = AppState {
|
||||
games,
|
||||
game_expiry,
|
||||
oidc_application,
|
||||
application_base,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", get(handle_index).post(handle_create))
|
||||
.route("/:id", get(handle_player).post(handle_player_answer))
|
||||
.route("/:id/events", get(sse_player))
|
||||
.route("/:id/view", get(handle_view).post(handle_view_next))
|
||||
.route("/:id/view/events", get(sse_view))
|
||||
.layer(oidc_login_service)
|
||||
.route("/:id", get(handle_player).post(handle_player_answer))
|
||||
.route("/:id/events", get(sse_player))
|
||||
.nest_service("/static", ServeDir::new("static"))
|
||||
.with_state(app_state);
|
||||
.with_state(app_state)
|
||||
.layer(oidc_auth_service)
|
||||
.layer(session_service);
|
||||
|
||||
axum::Server::bind(&"[::]:8080".parse().expect("valid listen address"))
|
||||
.serve(app.into_make_service())
|
||||
|
|
@ -111,15 +129,13 @@ pub async fn main() {
|
|||
.expect("axum server");
|
||||
}
|
||||
|
||||
pub async fn handle_index(
|
||||
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
|
||||
) -> HandlerResult<impl IntoResponse> {
|
||||
pub async fn handle_index() -> HandlerResult<impl IntoResponse> {
|
||||
Ok(Html(IndexTemplate {}.render_once()?))
|
||||
}
|
||||
|
||||
pub async fn handle_create(
|
||||
State(state): State<AppState>,
|
||||
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
|
||||
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
|
||||
mut body: Multipart,
|
||||
) -> HandlerResult<impl IntoResponse> {
|
||||
let mut quiz: Option<Quiz> = None;
|
||||
|
|
@ -137,11 +153,7 @@ pub async fn handle_create(
|
|||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let game = Game::new(
|
||||
game_id.clone(),
|
||||
oidc_extractor.claims.subject().to_string(),
|
||||
quiz,
|
||||
);
|
||||
let game = Game::new(game_id.clone(), claims.subject().to_string(), quiz);
|
||||
|
||||
let mut games = state.games.write().await;
|
||||
|
||||
|
|
@ -159,12 +171,12 @@ pub async fn handle_view(
|
|||
Path(id): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
HxRequest(htmx): HxRequest,
|
||||
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
|
||||
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
|
||||
) -> HandlerResult<impl IntoResponse> {
|
||||
let games = state.games.read().await;
|
||||
let game = games.get(&id).ok_or(Error::NotFound)?;
|
||||
|
||||
if game.owner != oidc_extractor.claims.subject().to_string() {
|
||||
if game.owner != claims.subject().to_string() {
|
||||
return Err(Error::Forbidden);
|
||||
}
|
||||
|
||||
|
|
@ -175,12 +187,12 @@ pub async fn handle_view_next(
|
|||
Path(id): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
HxRequest(htmx): HxRequest,
|
||||
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
|
||||
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
|
||||
) -> HandlerResult<impl IntoResponse> {
|
||||
let mut games = state.games.write().await;
|
||||
let game = games.get_mut(&id).ok_or(Error::NotFound)?;
|
||||
|
||||
if game.owner != oidc_extractor.claims.subject().to_string() {
|
||||
if game.owner != claims.subject().to_string() {
|
||||
return Err(Error::Forbidden);
|
||||
}
|
||||
|
||||
|
|
@ -192,12 +204,12 @@ pub async fn handle_view_next(
|
|||
pub async fn sse_view(
|
||||
Path(id): Path<String>,
|
||||
State(state): State<AppState>,
|
||||
oidc_extractor: OidcExtractor<EmptyAdditionalClaims>,
|
||||
OidcClaims(claims): OidcClaims<EmptyAdditionalClaims>,
|
||||
) -> HandlerResult<Sse<impl Stream<Item = Result<Event, Error>>>> {
|
||||
let games = state.games.read().await;
|
||||
let game = games.get(&id).ok_or(Error::NotFound)?;
|
||||
|
||||
if game.owner != oidc_extractor.claims.subject().to_string() {
|
||||
if game.owner != claims.subject().to_string() {
|
||||
return Err(Error::Forbidden);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue