integrated with axum_oidc

This commit is contained in:
Paul Zinselmeyer 2023-04-21 16:10:14 +02:00
parent d635797394
commit e08fa51637
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
5 changed files with 279 additions and 307 deletions

234
Cargo.lock generated
View file

@ -3,10 +3,45 @@
version = 3 version = 3
[[package]] [[package]]
name = "aho-corasick" name = "aead"
version = "0.7.20" version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82e1366e0c69c9f927b1fa5ce2c7bf9eafc8f9268c0b9800729e8b267612447c"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "aho-corasick"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67fc08ce920c31afb70f013dcce1bfc3a3195de6a228474e45e1f145b36f8d04"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
@ -88,6 +123,29 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-extra"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "febf23ab04509bd7672e6abe76bd8277af31b679e89fa5ffc6087dc289a448a3"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"http",
"http-body",
"mime",
"pin-project-lite",
"serde",
"tokio",
"tower",
"tower-http",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "axum-macros" name = "axum-macros"
version = "0.3.7" version = "0.3.7"
@ -100,6 +158,22 @@ dependencies = [
"syn 2.0.15", "syn 2.0.15",
] ]
[[package]]
name = "axum_oidc"
version = "0.1.0"
source = "git+https://git.zettoit.eu/pfz4/axum_oidc#75ed3b861a85cd18d8473a65cc6a90bc38528527"
dependencies = [
"async-trait",
"axum",
"axum-extra",
"cookie",
"openidconnect",
"reqwest",
"serde",
"serde_json",
"thiserror",
]
[[package]] [[package]]
name = "base16ct" name = "base16ct"
version = "0.1.1" version = "0.1.1"
@ -129,15 +203,14 @@ name = "bin"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"axum", "axum",
"axum_oidc",
"chrono", "chrono",
"dotenvy", "dotenvy",
"futures-util", "futures-util",
"markdown", "markdown",
"openidconnect",
"parse_duration", "parse_duration",
"rand", "rand",
"render", "render",
"reqwest",
"serde", "serde",
"serde_cbor", "serde_cbor",
"thiserror", "thiserror",
@ -201,11 +274,21 @@ dependencies = [
"num-integer", "num-integer",
"num-traits", "num-traits",
"serde", "serde",
"time", "time 0.1.45",
"wasm-bindgen", "wasm-bindgen",
"winapi", "winapi",
] ]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]] [[package]]
name = "codespan-reporting" name = "codespan-reporting"
version = "0.11.1" version = "0.11.1"
@ -222,6 +305,21 @@ version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520fbf3c07483f94e3e3ca9d0cfd913d7718ef2483d2cfd91c0d9e91474ab913" checksum = "520fbf3c07483f94e3e3ca9d0cfd913d7718ef2483d2cfd91c0d9e91474ab913"
[[package]]
name = "cookie"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7efb37c3e1ccb1ff97164ad95ac1606e8ccd35b3fa0a7d99a304c7f4a428cc24"
dependencies = [
"aes-gcm",
"base64 0.21.0",
"percent-encoding",
"rand",
"subtle",
"time 0.3.20",
"version_check",
]
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.4" version = "0.8.4"
@ -230,9 +328,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa"
[[package]] [[package]]
name = "cpufeatures" name = "cpufeatures"
version = "0.2.6" version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "280a9f2d8b3a38871a3c8a46fb80db65e5e5ed97da80c4d08bf27fb63e35e181" checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58"
dependencies = [ dependencies = [
"libc", "libc",
] ]
@ -256,9 +354,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [ dependencies = [
"generic-array", "generic-array",
"rand_core",
"typenum", "typenum",
] ]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]] [[package]]
name = "cxx" name = "cxx"
version = "1.0.94" version = "1.0.94"
@ -530,6 +638,16 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "ghash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d930750de5717d2dd0b8c0d42c076c0e884c81a73e6cab859bbd2339c71e3e40"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]] [[package]]
name = "group" name = "group"
version = "0.12.1" version = "0.12.1"
@ -652,6 +770,12 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "http-range-header"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]] [[package]]
name = "httparse" name = "httparse"
version = "1.8.0" version = "1.8.0"
@ -751,6 +875,15 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "inout"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
dependencies = [
"generic-array",
]
[[package]] [[package]]
name = "ipnet" name = "ipnet"
version = "2.7.2" version = "2.7.2"
@ -792,9 +925,9 @@ dependencies = [
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.141" version = "0.2.142"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317"
[[package]] [[package]]
name = "libm" name = "libm"
@ -1002,6 +1135,12 @@ version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]] [[package]]
name = "openidconnect" name = "openidconnect"
version = "3.0.0" version = "3.0.0"
@ -1173,6 +1312,18 @@ dependencies = [
"spki", "spki",
] ]
[[package]]
name = "polyval"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ef234e08c11dfcb2e56f79fd70f6f2eb7f025c0ce2333e82f4f0518ecad30c6"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.17" version = "0.2.17"
@ -1262,9 +1413,9 @@ dependencies = [
[[package]] [[package]]
name = "regex" name = "regex"
version = "1.7.3" version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
@ -1273,9 +1424,9 @@ dependencies = [
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.6.29" version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c"
[[package]] [[package]]
name = "render" name = "render"
@ -1716,6 +1867,33 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "time"
version = "0.3.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd0cbfecb4d19b5ea75bb31ad904eb5b9fa13f21079c3b92017ebdf4999a5890"
dependencies = [
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd"
[[package]]
name = "time-macros"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd80a657e71da814b8e5d60d3374fc6d35045062245d80224748ae522dd76f36"
dependencies = [
"time-core",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@ -1802,6 +1980,24 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d1d42a9b3f3ec46ba828e8d376aec14592ea199f70a06a548587ecd1c4ab658"
dependencies = [
"bitflags",
"bytes",
"futures-core",
"futures-util",
"http",
"http-body",
"http-range-header",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.2" version = "0.3.2"
@ -1874,6 +2070,16 @@ version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
[[package]]
name = "universal-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d3160b73c9a19f7e2939a2fdad446c57c1bbbbf4d919d3213ff1267a580d8b5"
dependencies = [
"crypto-common",
"subtle",
]
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.7.1" version = "0.7.1"

View file

@ -12,13 +12,11 @@ futures-util = "0.3"
axum = {version="0.6", features=["macros", "headers"]} axum = {version="0.6", features=["macros", "headers"]}
serde = "1.0" serde = "1.0"
serde_cbor = "0.11" serde_cbor = "0.11"
openidconnect = "3.0"
render = { git="https://github.com/render-rs/render.rs" } render = { git="https://github.com/render-rs/render.rs" }
thiserror = "1.0.40" thiserror = "1.0.40"
rand = "0.8.5" rand = "0.8.5"
dotenvy = "0.15" dotenvy = "0.15"
reqwest = { version="0.11", default_features=false}
markdown = "0.3.0" markdown = "0.3.0"
chrono = { version="0.4", features=["serde"]} chrono = { version="0.4", features=["serde"]}
parse_duration = "2.1" parse_duration = "2.1"
axum_oidc = {git="https://git.zettoit.eu/pfz4/axum_oidc"}

View file

@ -1,25 +1,18 @@
use std::{ use std::{collections::BTreeMap, env, str::FromStr, sync::Arc, time::Duration};
collections::{BTreeMap, HashMap},
env,
fmt::LowerExp,
str::FromStr,
sync::Arc,
time::Duration,
};
use axum::{ use axum::{
body::StreamBody, body::StreamBody,
extract::{BodyStream, Path, Query, State}, extract::{BodyStream, FromRef, Path, Query, State},
headers::ContentType, headers::ContentType,
http::{header, HeaderMap, StatusCode}, http::{header, HeaderMap, StatusCode},
response::{Html, IntoResponse, Redirect}, response::{Html, IntoResponse, Redirect},
routing::get, routing::get,
Router, TypedHeader, Router, TypedHeader,
}; };
use axum_oidc::{ClaimsExtractor, EmptyAdditionalClaims, Key, OidcApplication};
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use futures_util::StreamExt; use futures_util::StreamExt;
use metadata::Metadata; use metadata::Metadata;
use openid::Login;
use render::{html, raw}; use render::{html, raw};
use serde::Deserialize; use serde::Deserialize;
use tokio::{ use tokio::{
@ -30,7 +23,6 @@ use tokio::{
use tokio_util::io::ReaderStream; use tokio_util::io::ReaderStream;
pub mod metadata; pub mod metadata;
pub mod openid;
// RFC 7230 section 3.1.1 // RFC 7230 section 3.1.1
// It is RECOMMENDED that all HTTP senders and recipients // It is RECOMMENDED that all HTTP senders and recipients
@ -77,8 +69,21 @@ pub struct AppState {
client_id: String, client_id: String,
client_secret: Option<String>, client_secret: Option<String>,
scopes: Vec<String>, scopes: Vec<String>,
logins: Arc<Mutex<HashMap<String, Login>>>,
expire: Arc<Mutex<BTreeMap<NaiveDateTime, String>>>, expire: Arc<Mutex<BTreeMap<NaiveDateTime, String>>>,
key: Key,
}
impl FromRef<AppState> for OidcApplication {
fn from_ref(input: &AppState) -> Self {
OidcApplication::new(
input.application_base.to_string(),
input.issuer.to_string(),
input.client_id.to_string(),
input.client_secret.to_owned(),
input.scopes.clone(),
input.key.clone(),
)
}
} }
#[tokio::main] #[tokio::main]
@ -109,15 +114,14 @@ async fn main() {
client_id, client_id,
client_secret, client_secret,
scopes, scopes,
logins: Arc::new(Mutex::new(HashMap::new())),
expire: expire.clone(), expire: expire.clone(),
key: Key::generate(),
}; };
tokio::spawn(async move { expire_thread("data".to_string(), expire).await }); tokio::spawn(async move { expire_thread("data".to_string(), expire).await });
let app = Router::new() let app = Router::new()
.route("/", get(openid::handle_login)) .route("/", get(get_index))
.route("/login/:id", get(openid::handle_callback))
.route("/:id", get(get_item).post(post_item).put(post_item)) .route("/:id", get(get_item).post(post_item).put(post_item))
.with_state(state); .with_state(state);
axum::Server::bind(&"0.0.0.0:3000".parse().unwrap()) axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
@ -126,6 +130,17 @@ async fn main() {
.unwrap(); .unwrap();
} }
async fn get_index(
State(app_state): State<AppState>,
ClaimsExtractor(claims): ClaimsExtractor<EmptyAdditionalClaims>,
) -> impl IntoResponse {
let subject = claims.subject().to_string();
let (id, _) = Metadata::create(&app_state.path, subject).await.unwrap();
Redirect::temporary(&format!("{}{}", app_state.application_base, id))
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct PostQuery { pub struct PostQuery {
ttl: Option<String>, ttl: Option<String>,
@ -159,13 +174,11 @@ async fn post_item(
} else { } else {
Ok(Utc::now().naive_utc() + chrono::Duration::days(30)) Ok(Utc::now().naive_utc() + chrono::Duration::days(30))
}?; }?;
{
app_state.expire.lock().await.insert(expires_at, id.clone());
}
let mut data_file = File::create(&format!("{}/{}.data", app_state.path, &id)) let mut data_file = File::create(&format!("{}/{}.data", app_state.path, &id))
.await .await
.unwrap(); .unwrap();
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
let buf = chunk.map(|x| x.to_vec()).unwrap_or_default(); let buf = chunk.map(|x| x.to_vec()).unwrap_or_default();
data_file.write_all(&buf).await.unwrap(); data_file.write_all(&buf).await.unwrap();
@ -173,9 +186,10 @@ async fn post_item(
metadata.mimetype = content_type.map(|x| x.0.to_string()); metadata.mimetype = content_type.map(|x| x.0.to_string());
metadata.ttl = Some(expires_at); metadata.ttl = Some(expires_at);
metadata.to_file(&app_state.path, &id).await.unwrap(); metadata.to_file(&app_state.path, &id).await.unwrap();
app_state.expire.lock().await.insert(expires_at, id);
Ok((StatusCode::CREATED, "OK")) Ok((StatusCode::CREATED, "OK"))
} else { } else {
Err(Error::DataFileExists) Err(Error::DataFileExists)
@ -279,9 +293,8 @@ async fn expire_thread(path: String, expire_dates: Arc<Mutex<BTreeMap<NaiveDateT
fs::remove_file(&format!("{}/{}.meta", &path, &id)) fs::remove_file(&format!("{}/{}.meta", &path, &id))
.await .await
.unwrap(); .unwrap();
fs::remove_file(&format!("{}/{}.data", &path, &id)) let _ = fs::remove_file(&format!("{}/{}.data", &path, &id)).await;
.await
.unwrap();
to_delete.push(*expire); to_delete.push(*expire);
} else { } else {
break; break;

View file

@ -1,4 +1,5 @@
use chrono::NaiveDateTime; use chrono::{NaiveDateTime, Utc};
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::{ use tokio::{
fs::File, fs::File,
@ -19,9 +20,30 @@ pub struct Metadata {
pub subject: String, pub subject: String,
pub ttl: Option<NaiveDateTime>, pub ttl: Option<NaiveDateTime>,
pub mimetype: Option<String>, pub mimetype: Option<String>,
pub sha256: Option<String>,
pub sha512: Option<String>,
} }
impl Metadata { impl Metadata {
pub async fn create(path: &str, subject: String) -> Result<(String, Self), Error> {
let id = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(8)
.map(char::from)
.collect::<String>();
let metadata = Metadata {
subject,
mimetype: None,
ttl: Some(Utc::now().naive_utc() + chrono::Duration::days(1)),
sha256: None,
sha512: None,
};
metadata.to_file(path, &id).await.unwrap();
Ok((id, metadata))
}
pub async fn from_file(path: &str, id: &str) -> Result<Self, Error> { pub async fn from_file(path: &str, id: &str) -> Result<Self, Error> {
let mut metadata_file = File::open(&format!("{}/{}.meta", path, id)).await?; let mut metadata_file = File::open(&format!("{}/{}.meta", path, id)).await?;
let mut metadata = Vec::new(); let mut metadata = Vec::new();

View file

@ -1,267 +0,0 @@
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{IntoResponse, Redirect},
};
use openidconnect::{
core::{CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, CoreProviderMetadata},
reqwest::async_http_client,
url::ParseError,
AccessTokenHash, AuthorizationCode, ClaimsVerificationError, ClientId, ClientSecret, CsrfToken,
DiscoveryError, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier,
RedirectUrl, RequestTokenError, Scope, SigningError, StandardErrorResponse, TokenResponse,
};
use rand::{distributions::Alphanumeric, Rng};
use serde::Deserialize;
use crate::{metadata::Metadata, AppState};
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("discovery error: {:?}", 0)]
Discovery(#[from] DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
#[error("parse error: {:?}", 0)]
Parse(#[from] ParseError),
#[error("request token error: {:?}", 0)]
RequestToken(
#[from]
RequestTokenError<
openidconnect::reqwest::Error<reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>,
),
#[error("claims verification error: {:?}", 0)]
ClaimsVerification(#[from] ClaimsVerificationError),
#[error("signing error: {:?}", 0)]
SigningError(#[from] SigningError),
#[error("id token not found")]
IdTokenNotFound,
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
println!("openid error: {:?}", self);
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
}
}
#[derive(Debug, Deserialize)]
pub struct OidcBody {
pub code: String,
pub state: String,
pub session_state: String,
}
#[derive(Debug, Clone)]
pub struct Login {
csrf_token: String,
nonce: String,
pkce_verifier: String,
}
pub async fn handle_callback(
State(state): State<AppState>,
Path(auth_id): Path<String>,
Query(body): Query<OidcBody>,
) -> Result<impl IntoResponse, Error> {
let auth_instance = {
let logins = state.logins.lock().await;
logins.get(&auth_id).cloned().unwrap()
};
let client = create_oidc_client(
state.issuer.clone(),
state.client_id.clone(),
state.client_secret.clone(),
&state.application_base,
&auth_id,
)
.await?;
if auth_instance.csrf_token != body.state {
return Ok((StatusCode::BAD_REQUEST, "csrf token is invalid").into_response());
}
let pkce_verifier = PkceCodeVerifier::new(auth_instance.pkce_verifier.clone());
let nonce = Nonce::new(auth_instance.nonce.clone());
let token_response = client
.exchange_code(AuthorizationCode::new(body.code.to_string()))
// Set the PKCE code verifier.
.set_pkce_verifier(pkce_verifier)
.request_async(async_http_client)
.await?;
// Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response.id_token().ok_or(Error::IdTokenNotFound)?;
let claims = id_token.claims(&client.id_token_verifier(), &nonce)?;
// Verify the access token hash to ensure that the access token hasn't been substituted for
// another user's.
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash =
AccessTokenHash::from_token(token_response.access_token(), &id_token.signing_alg()?)?;
if actual_access_token_hash != *expected_access_token_hash {
return Ok((StatusCode::BAD_REQUEST, "access token hash is invalid").into_response());
}
}
//let mut oidc_user = oidc_user::Entity::find()
// .filter(
// oidc_user::Column::OidcClientId
// .eq(oidc_client.id)
// .and(oidc_user::Column::Subject.eq(claims.subject().as_str())),
// )
// .one(&state.db)
// .await?
// .map(|x| x.into_active_model())
// .unwrap_or_default();
//oidc_user.oidc_client_id = ActiveValue::Set(oidc_client.id);
//oidc_user.subject = ActiveValue::Set(claims.subject().to_string());
//oidc_user.email = ActiveValue::Set(claims.email().map(|x| x.to_string()).unwrap_or_default());
//oidc_user.username = ActiveValue::Set(
// claims
// .preferred_username()
// .map(|x| x.to_string())
// .unwrap_or_default(),
//);
//oidc_user.given_name = ActiveValue::Set(
// claims
// .given_name()
// .and_then(|x| x.get(None).map(|x| x.to_string()))
// .unwrap_or_default(),
//);
//oidc_user.middle_name = ActiveValue::Set(
// claims
// .middle_name()
// .and_then(|x| x.get(None).map(|x| x.to_string()))
// .unwrap_or_default(),
//);
//oidc_user.family_name = ActiveValue::Set(
// claims
// .family_name()
// .and_then(|x| x.get(None).map(|x| x.to_string()))
// .unwrap_or_default(),
//);
//oidc_user.locale = ActiveValue::Set(claims.locale().map(|x| x.to_string()).unwrap_or_default());
//oidc_user.zoneinfo =
// ActiveValue::Set(claims.zoneinfo().map(|x| x.to_string()).unwrap_or_default());
//let oidc_user = if oidc_user.id.is_unchanged() {
// oidc_user.update(&state.db).await?
//} else {
// oidc_user.insert(&state.db).await?
//};
//let instance = if form.multiple_submissions == 0 {
// DatabaseInstance::from_userid(
// state.db.clone(),
// state.submit_producer.clone(),
// oidc_user.id,
// form.id,
// )
// .await?
//} else {
// None
//};
//let instance = match instance {
// Some(x) => x,
// None => {
// DatabaseInstance::new(
// state.db.clone(),
// state.submit_producer.clone(),
// form.id,
// Some(oidc_user.id),
// )
// .await?
// }
//};
{
state.logins.lock().await.remove(&auth_id);
}
let subject = claims.subject().to_string();
let id = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(8)
.map(char::from)
.collect::<String>();
let metadata = Metadata {
subject,
mimetype: None,
ttl: None,
};
metadata.to_file(&state.path, &id).await.unwrap();
Ok((Redirect::temporary(&format!("{}{}", state.application_base, id))).into_response())
}
pub async fn handle_login(State(state): State<AppState>) -> Result<impl IntoResponse, Error> {
let auth_id = rand::thread_rng()
.sample_iter(Alphanumeric)
.take(32)
.map(char::from)
.collect::<String>();
let client = create_oidc_client(
state.issuer.clone(),
state.client_id.clone(),
state.client_secret.clone(),
&state.application_base,
&auth_id,
)
.await?;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = {
let mut auth = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in state.scopes.iter() {
auth = auth.add_scope(Scope::new(scope.to_string()));
}
auth.set_pkce_challenge(pkce_challenge).url()
};
{
let mut logins = state.logins.lock().await;
logins.insert(
auth_id,
Login {
csrf_token: csrf_token.secret().to_string(),
nonce: nonce.secret().to_string(),
pkce_verifier: pkce_verifier.secret().to_string(),
},
);
}
Ok(Redirect::temporary(auth_url.as_str()).into_response())
}
async fn create_oidc_client(
issuer: String,
client_id: String,
client_secret: Option<String>,
application_base: &str,
auth_id: &str,
) -> Result<CoreClient, Error> {
let provider_metadata =
CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?;
let client = CoreClient::from_provider_metadata(
provider_metadata,
ClientId::new(client_id.clone()),
client_secret.map(ClientSecret::new),
)
.set_redirect_uri(RedirectUrl::new(format!(
"{}/login/{}",
application_base, auth_id
))?);
Ok(client)
}