This commit is contained in:
Paul Zinselmeyer 2023-04-21 15:11:37 +02:00
commit 75ed3b861a
4 changed files with 294 additions and 0 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
/target
/Cargo.lock

17
Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "axum_oidc"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.6"
axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
cookie = "0.17"
openidconnect = "3.0"
async-trait = "0.1"
serde = "1.0"
thiserror = "1.0"
reqwest = { version="0.11", default_features=false}
serde_json = "1.0"

54
src/error.rs Normal file
View file

@ -0,0 +1,54 @@
use axum::response::{IntoResponse, Redirect};
use axum_extra::extract::PrivateCookieJar;
use openidconnect::{
core::CoreErrorResponseType, url::ParseError, ClaimsVerificationError, DiscoveryError,
SigningError, StandardErrorResponse,
};
use reqwest::StatusCode;
type RequestTokenError = openidconnect::RequestTokenError<
openidconnect::reqwest::Error<reqwest::Error>,
StandardErrorResponse<CoreErrorResponseType>,
>;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("discovery error: {:?}", 0)]
Discovery(#[from] DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
#[error("parse error: {:?}", 0)]
Parse(#[from] ParseError),
#[error("request token error: {:?}", 0)]
RequestToken(#[from] RequestTokenError),
#[error("claims verification error: {:?}", 0)]
ClaimsVerification(#[from] ClaimsVerificationError),
#[error("signing error: {:?}", 0)]
Signing(#[from] SigningError),
#[error("json serialization error: {:?}", 0)]
Json(#[from] serde_json::Error),
#[error("csrf token is invalid")]
CsrfTokenInvalid,
#[error("id token not found")]
IdTokenNotFound,
#[error("access token hash is invalid")]
AccessTokenHashInvalid,
#[error("just a redirect")]
Redirect((PrivateCookieJar, Redirect)),
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
match self {
Self::CsrfTokenInvalid => {
{ (StatusCode::BAD_REQUEST, "csrf token is invalid").into_response() }
.into_response()
}
Self::Redirect(redirect) => redirect.into_response(),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
}
}

221
src/lib.rs Normal file
View file

@ -0,0 +1,221 @@
use async_trait::async_trait;
use axum::{
extract::{FromRef, FromRequestParts, Query},
http::request::Parts,
response::Redirect,
};
use axum_extra::extract::{
cookie::{Cookie, SameSite},
PrivateCookieJar,
};
use cookie::time::{Duration, OffsetDateTime};
use error::Error;
use openidconnect::{
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreErrorResponseType,
CoreGenderClaim, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse,
CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenType,
},
reqwest::async_http_client,
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EmptyExtraTokenFields, IdTokenClaims, IdTokenFields, IssuerUrl, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardErrorResponse,
StandardTokenResponse, TokenResponse,
};
use serde::{Deserialize, Serialize};
pub use cookie::Key;
pub mod error;
const LOGIN_COOKIE_NAME: &str = "OIDC_LOGIN";
pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {}
type OidcTokenResponse<AC> = StandardTokenResponse<
IdTokenFields<
AC,
EmptyExtraTokenFields,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
>,
CoreTokenType,
>;
pub type OidcClient<AC> = Client<
AC,
CoreAuthDisplay,
CoreGenderClaim,
CoreJweContentEncryptionAlgorithm,
CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
CoreJsonWebKeyUse,
CoreJsonWebKey,
CoreAuthPrompt,
StandardErrorResponse<CoreErrorResponseType>,
OidcTokenResponse<AC>,
CoreTokenType,
CoreTokenIntrospectionResponse,
CoreRevocableToken,
CoreRevocationErrorResponse,
>;
pub struct OidcApplication {
application_base: String,
issuer: IssuerUrl,
client_id: ClientId,
client_secret: Option<ClientSecret>,
scopes: Vec<String>,
cookie_key: Key,
}
impl OidcApplication {
pub fn new(
application_base: String,
issuer: String,
client_id: String,
client_secret: Option<String>,
scopes: Vec<String>,
cookie_key: Key,
) -> Self {
Self {
application_base,
issuer: IssuerUrl::new(issuer).unwrap(),
client_id: ClientId::new(client_id),
client_secret: client_secret.map(ClientSecret::new),
scopes,
cookie_key,
}
}
async fn create_client<AC: AdditionalClaims>(
&self,
redirect: String,
) -> Result<OidcClient<AC>, Error> {
let provider_metadata =
CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?;
let client = OidcClient::<AC>::from_provider_metadata(
provider_metadata,
self.client_id.clone(),
self.client_secret.clone(),
)
.set_redirect_uri(RedirectUrl::new(redirect)?);
Ok(client)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmptyAdditionalClaims {}
impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
impl AdditionalClaims for EmptyAdditionalClaims {}
pub struct ClaimsExtractor<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
#[async_trait]
impl<S, AC> FromRequestParts<S> for ClaimsExtractor<AC>
where
S: Send + Sync,
AC: AdditionalClaims,
OidcApplication: FromRef<S>,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let application: OidcApplication = OidcApplication::from_ref(state);
let client = application
.create_client(format!(
"{}/{}",
application.application_base,
parts.uri.path()
))
.await?;
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
let login_session = jar.get(LOGIN_COOKIE_NAME);
let query = Query::<OidcQuery>::from_request_parts(parts, state)
.await
.ok();
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
let login_session: LoginSession = serde_json::from_str(login_session.value())?;
if login_session.csrf_token.secret() != &query.state {
return Err(Error::CsrfTokenInvalid);
}
let token_response = client
.exchange_code(AuthorizationCode::new(query.code.to_string()))
// Set the PKCE code verifier.
.set_pkce_verifier(login_session.pkce_verifier)
.request_async(async_http_client)
.await?;
// Extract the ID token claims after verifying its authenticity and nonce.
let id_token = token_response.id_token().ok_or(Error::IdTokenNotFound)?;
let claims = id_token.claims(&client.id_token_verifier(), &login_session.nonce)?;
// Verify the access token hash to ensure that the access token hasn't been substituted for
// another user's.
if let Some(expected_access_token_hash) = claims.access_token_hash() {
let actual_access_token_hash = AccessTokenHash::from_token(
token_response.access_token(),
&id_token.signing_alg()?,
)?;
if actual_access_token_hash != *expected_access_token_hash {
return Err(Error::AccessTokenHashInvalid);
}
}
Ok(Self(claims.clone()))
} else {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = {
let mut auth = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
for scope in application.scopes.iter() {
auth = auth.add_scope(Scope::new(scope.to_string()));
}
auth.set_pkce_challenge(pkce_challenge).url()
};
let login_session = LoginSession {
nonce,
csrf_token,
pkce_verifier,
};
let login_session = serde_json::to_string(&login_session)?;
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
cookie.set_same_site(SameSite::Lax);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1));
jar = jar.add(cookie);
Err(Error::Redirect((
jar,
Redirect::temporary(auth_url.as_str()),
)))
}
}
}
#[derive(Debug, Deserialize)]
struct OidcQuery {
code: String,
state: String,
#[allow(dead_code)]
session_state: String,
}
#[derive(Serialize, Deserialize)]
struct LoginSession {
nonce: Nonce,
csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier,
}