jwt as extractor
This commit is contained in:
parent
89da8cc07f
commit
3b9438d1f3
4 changed files with 39 additions and 76 deletions
2
Cargo.lock
generated
2
Cargo.lock
generated
|
@ -162,14 +162,12 @@ dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum",
|
"axum",
|
||||||
"axum-extra",
|
"axum-extra",
|
||||||
"futures-util",
|
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"openidconnect",
|
"openidconnect",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tower",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
@ -18,10 +18,8 @@ serde = "1.0"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
|
||||||
jsonwebtoken = {version="^8.3", optional=true}
|
jsonwebtoken = {version="^8.3", optional=true}
|
||||||
tower = {version="^0.4", optional=true}
|
|
||||||
futures-util = {version="^0.3",optional=true}
|
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = [ "jwt", "oidc" ]
|
default = [ "jwt", "oidc" ]
|
||||||
oidc = [ "openidconnect", "axum-extra" ]
|
oidc = [ "openidconnect", "axum-extra" ]
|
||||||
jwt = [ "tower", "jsonwebtoken", "futures-util", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ]
|
jwt = [ "jsonwebtoken", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ]
|
||||||
|
|
|
@ -65,6 +65,10 @@ pub enum Error {
|
||||||
#[cfg(feature = "jwt")]
|
#[cfg(feature = "jwt")]
|
||||||
#[error("jsonwebtoken: {0}")]
|
#[error("jsonwebtoken: {0}")]
|
||||||
JsonWebToken(#[from] jsonwebtoken::errors::Error),
|
JsonWebToken(#[from] jsonwebtoken::errors::Error),
|
||||||
|
|
||||||
|
#[cfg(feature = "jwt")]
|
||||||
|
#[error("jwt invalid")]
|
||||||
|
JwtInvalid,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
|
@ -78,6 +82,10 @@ impl IntoResponse for Error {
|
||||||
|
|
||||||
#[cfg(feature = "oidc")]
|
#[cfg(feature = "oidc")]
|
||||||
Self::Redirect(redirect) => redirect.into_response(),
|
Self::Redirect(redirect) => redirect.into_response(),
|
||||||
|
|
||||||
|
#[cfg(feature = "jwt")]
|
||||||
|
Self::JwtInvalid => (StatusCode::UNAUTHORIZED, "access token invalid").into_response(),
|
||||||
|
|
||||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
101
src/jwt.rs
101
src/jwt.rs
|
@ -1,18 +1,12 @@
|
||||||
use std::{
|
use std::marker::PhantomData;
|
||||||
marker::PhantomData,
|
|
||||||
task::{Context, Poll},
|
|
||||||
};
|
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
extract::{FromRef, FromRequestParts},
|
||||||
http::Request,
|
http::request::Parts,
|
||||||
response::{IntoResponse, Response},
|
|
||||||
};
|
};
|
||||||
use futures_util::future::BoxFuture;
|
|
||||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||||
use reqwest::StatusCode;
|
use serde::{de::DeserializeOwned, Deserialize};
|
||||||
use serde::Deserialize;
|
|
||||||
use tower::{Layer, Service};
|
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
|
|
||||||
|
@ -35,10 +29,8 @@ pub struct Claims<A: Clone> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct JwtLayer<A: Clone> {
|
pub struct JwtApplication<A: Clone> {
|
||||||
algorithm: Algorithm,
|
validation: Validation,
|
||||||
issuer: Vec<String>,
|
|
||||||
audience: Vec<String>,
|
|
||||||
pubkey: DecodingKey,
|
pubkey: DecodingKey,
|
||||||
_a: PhantomData<A>,
|
_a: PhantomData<A>,
|
||||||
}
|
}
|
||||||
|
@ -48,7 +40,7 @@ struct IssuerDiscovery {
|
||||||
public_key: String,
|
public_key: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A: Clone> JwtLayer<A> {
|
impl<A: Clone> JwtApplication<A> {
|
||||||
pub async fn new(issuer: String, audience: String) -> Result<Self, Error> {
|
pub async fn new(issuer: String, audience: String) -> Result<Self, Error> {
|
||||||
let issuer_key = reqwest::get(&issuer)
|
let issuer_key = reqwest::get(&issuer)
|
||||||
.await?
|
.await?
|
||||||
|
@ -63,78 +55,45 @@ impl<A: Clone> JwtLayer<A> {
|
||||||
|
|
||||||
let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?;
|
let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?;
|
||||||
|
|
||||||
|
let mut validation = Validation::new(Algorithm::RS256);
|
||||||
|
validation.set_issuer(&[issuer]);
|
||||||
|
validation.set_audience(&[audience]);
|
||||||
|
validation.validate_nbf = true;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
algorithm: Algorithm::RS256,
|
validation,
|
||||||
issuer: vec![issuer],
|
|
||||||
audience: vec![audience],
|
|
||||||
pubkey,
|
pubkey,
|
||||||
_a: PhantomData,
|
_a: PhantomData,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, A: Clone> Layer<S> for JwtLayer<A> {
|
#[async_trait]
|
||||||
type Service = JwtService<S, A>;
|
impl<S, A> FromRequestParts<S> for Claims<A>
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
|
||||||
let mut validation = Validation::new(self.algorithm);
|
|
||||||
validation.set_issuer(&self.issuer);
|
|
||||||
validation.set_audience(&self.audience);
|
|
||||||
validation.validate_nbf = true;
|
|
||||||
JwtService {
|
|
||||||
validation,
|
|
||||||
pubkey: self.pubkey.clone(),
|
|
||||||
inner,
|
|
||||||
_a: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct JwtService<S, A: Clone> {
|
|
||||||
validation: Validation,
|
|
||||||
pubkey: DecodingKey,
|
|
||||||
inner: S,
|
|
||||||
_a: PhantomData<A>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, A: Clone> Service<Request<Body>> for JwtService<S, A>
|
|
||||||
where
|
where
|
||||||
S: Service<Request<Body>, Response = Response> + Send + 'static,
|
S: Send + Sync,
|
||||||
S::Future: Send + 'static,
|
A: Clone + DeserializeOwned,
|
||||||
A: Clone + for<'a> Deserialize<'a> + 'static + Sync + Send,
|
JwtApplication<A>: FromRef<S>,
|
||||||
{
|
{
|
||||||
type Response = S::Response;
|
type Rejection = Error;
|
||||||
type Error = S::Error;
|
|
||||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
|
||||||
|
|
||||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
self.inner.poll_ready(cx)
|
let application: JwtApplication<A> = JwtApplication::from_ref(state);
|
||||||
}
|
|
||||||
|
|
||||||
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
|
let token = parts
|
||||||
let token = req
|
.headers
|
||||||
.headers()
|
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
.and_then(|x| x.to_str().ok())
|
.and_then(|x| x.to_str().ok())
|
||||||
.map(|x| x.chars().skip(7).collect::<String>());
|
.map(|x| x.chars().skip(7).collect::<String>());
|
||||||
|
|
||||||
let token =
|
let token = token.and_then(|x| {
|
||||||
token.and_then(|x| decode::<Claims<A>>(&x, &self.pubkey, &self.validation).ok());
|
decode::<Claims<A>>(&x, &application.pubkey, &application.validation).ok()
|
||||||
let token_exists = token.is_some();
|
});
|
||||||
|
|
||||||
if let Some(token) = token {
|
if let Some(token) = token {
|
||||||
req.extensions_mut().insert(token.claims);
|
Ok(token.claims)
|
||||||
}
|
|
||||||
|
|
||||||
let future = self.inner.call(req);
|
|
||||||
Box::pin(async move {
|
|
||||||
if token_exists {
|
|
||||||
let response: Response = future.await?;
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
} else {
|
||||||
Ok((StatusCode::UNAUTHORIZED, "access token invalid").into_response())
|
Err(Error::JwtInvalid)
|
||||||
}
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue