jwt as extractor

This commit is contained in:
Paul Zinselmeyer 2023-10-20 00:26:59 +02:00
parent 89da8cc07f
commit 3b9438d1f3
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
4 changed files with 39 additions and 76 deletions

2
Cargo.lock generated
View file

@ -162,14 +162,12 @@ dependencies = [
"async-trait", "async-trait",
"axum", "axum",
"axum-extra", "axum-extra",
"futures-util",
"jsonwebtoken", "jsonwebtoken",
"openidconnect", "openidconnect",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
"thiserror", "thiserror",
"tower",
] ]
[[package]] [[package]]

View file

@ -18,10 +18,8 @@ serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
jsonwebtoken = {version="^8.3", optional=true} jsonwebtoken = {version="^8.3", optional=true}
tower = {version="^0.4", optional=true}
futures-util = {version="^0.3",optional=true}
[features] [features]
default = [ "jwt", "oidc" ] default = [ "jwt", "oidc" ]
oidc = [ "openidconnect", "axum-extra" ] oidc = [ "openidconnect", "axum-extra" ]
jwt = [ "tower", "jsonwebtoken", "futures-util", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ] jwt = [ "jsonwebtoken", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ]

View file

@ -65,6 +65,10 @@ pub enum Error {
#[cfg(feature = "jwt")] #[cfg(feature = "jwt")]
#[error("jsonwebtoken: {0}")] #[error("jsonwebtoken: {0}")]
JsonWebToken(#[from] jsonwebtoken::errors::Error), JsonWebToken(#[from] jsonwebtoken::errors::Error),
#[cfg(feature = "jwt")]
#[error("jwt invalid")]
JwtInvalid,
} }
impl IntoResponse for Error { impl IntoResponse for Error {
@ -78,6 +82,10 @@ impl IntoResponse for Error {
#[cfg(feature = "oidc")] #[cfg(feature = "oidc")]
Self::Redirect(redirect) => redirect.into_response(), Self::Redirect(redirect) => redirect.into_response(),
#[cfg(feature = "jwt")]
Self::JwtInvalid => (StatusCode::UNAUTHORIZED, "access token invalid").into_response(),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
} }
} }

View file

@ -1,18 +1,12 @@
use std::{ use std::marker::PhantomData;
marker::PhantomData,
task::{Context, Poll},
};
use async_trait::async_trait;
use axum::{ use axum::{
body::Body, extract::{FromRef, FromRequestParts},
http::Request, http::request::Parts,
response::{IntoResponse, Response},
}; };
use futures_util::future::BoxFuture;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use reqwest::StatusCode; use serde::{de::DeserializeOwned, Deserialize};
use serde::Deserialize;
use tower::{Layer, Service};
use crate::error::Error; use crate::error::Error;
@ -35,10 +29,8 @@ pub struct Claims<A: Clone> {
} }
#[derive(Clone)] #[derive(Clone)]
pub struct JwtLayer<A: Clone> { pub struct JwtApplication<A: Clone> {
algorithm: Algorithm, validation: Validation,
issuer: Vec<String>,
audience: Vec<String>,
pubkey: DecodingKey, pubkey: DecodingKey,
_a: PhantomData<A>, _a: PhantomData<A>,
} }
@ -48,7 +40,7 @@ struct IssuerDiscovery {
public_key: String, public_key: String,
} }
impl<A: Clone> JwtLayer<A> { impl<A: Clone> JwtApplication<A> {
pub async fn new(issuer: String, audience: String) -> Result<Self, Error> { pub async fn new(issuer: String, audience: String) -> Result<Self, Error> {
let issuer_key = reqwest::get(&issuer) let issuer_key = reqwest::get(&issuer)
.await? .await?
@ -63,78 +55,45 @@ impl<A: Clone> JwtLayer<A> {
let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?; let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer]);
validation.set_audience(&[audience]);
validation.validate_nbf = true;
Ok(Self { Ok(Self {
algorithm: Algorithm::RS256, validation,
issuer: vec![issuer],
audience: vec![audience],
pubkey, pubkey,
_a: PhantomData, _a: PhantomData,
}) })
} }
} }
impl<S, A: Clone> Layer<S> for JwtLayer<A> { #[async_trait]
type Service = JwtService<S, A>; impl<S, A> FromRequestParts<S> for Claims<A>
fn layer(&self, inner: S) -> Self::Service {
let mut validation = Validation::new(self.algorithm);
validation.set_issuer(&self.issuer);
validation.set_audience(&self.audience);
validation.validate_nbf = true;
JwtService {
validation,
pubkey: self.pubkey.clone(),
inner,
_a: PhantomData,
}
}
}
#[derive(Clone)]
pub struct JwtService<S, A: Clone> {
validation: Validation,
pubkey: DecodingKey,
inner: S,
_a: PhantomData<A>,
}
impl<S, A: Clone> Service<Request<Body>> for JwtService<S, A>
where where
S: Service<Request<Body>, Response = Response> + Send + 'static, S: Send + Sync,
S::Future: Send + 'static, A: Clone + DeserializeOwned,
A: Clone + for<'a> Deserialize<'a> + 'static + Sync + Send, JwtApplication<A>: FromRef<S>,
{ {
type Response = S::Response; type Rejection = Error;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
self.inner.poll_ready(cx) let application: JwtApplication<A> = JwtApplication::from_ref(state);
}
fn call(&mut self, mut req: Request<Body>) -> Self::Future { let token = parts
let token = req .headers
.headers()
.get("Authorization") .get("Authorization")
.and_then(|x| x.to_str().ok()) .and_then(|x| x.to_str().ok())
.map(|x| x.chars().skip(7).collect::<String>()); .map(|x| x.chars().skip(7).collect::<String>());
let token = let token = token.and_then(|x| {
token.and_then(|x| decode::<Claims<A>>(&x, &self.pubkey, &self.validation).ok()); decode::<Claims<A>>(&x, &application.pubkey, &application.validation).ok()
let token_exists = token.is_some(); });
if let Some(token) = token { if let Some(token) = token {
req.extensions_mut().insert(token.claims); Ok(token.claims)
} else {
Err(Error::JwtInvalid)
} }
let future = self.inner.call(req);
Box::pin(async move {
if token_exists {
let response: Response = future.await?;
Ok(response)
} else {
Ok((StatusCode::UNAUTHORIZED, "access token invalid").into_response())
}
})
} }
} }