init
This commit is contained in:
commit
75ed3b861a
4 changed files with 294 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
/target
|
||||||
|
/Cargo.lock
|
17
Cargo.toml
Normal file
17
Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
[package]
|
||||||
|
name = "axum_oidc"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
axum = "0.6"
|
||||||
|
axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
|
||||||
|
cookie = "0.17"
|
||||||
|
openidconnect = "3.0"
|
||||||
|
async-trait = "0.1"
|
||||||
|
serde = "1.0"
|
||||||
|
thiserror = "1.0"
|
||||||
|
reqwest = { version="0.11", default_features=false}
|
||||||
|
serde_json = "1.0"
|
54
src/error.rs
Normal file
54
src/error.rs
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
use axum::response::{IntoResponse, Redirect};
|
||||||
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
|
use openidconnect::{
|
||||||
|
core::CoreErrorResponseType, url::ParseError, ClaimsVerificationError, DiscoveryError,
|
||||||
|
SigningError, StandardErrorResponse,
|
||||||
|
};
|
||||||
|
use reqwest::StatusCode;
|
||||||
|
|
||||||
|
type RequestTokenError = openidconnect::RequestTokenError<
|
||||||
|
openidconnect::reqwest::Error<reqwest::Error>,
|
||||||
|
StandardErrorResponse<CoreErrorResponseType>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("discovery error: {:?}", 0)]
|
||||||
|
Discovery(#[from] DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
|
||||||
|
#[error("parse error: {:?}", 0)]
|
||||||
|
Parse(#[from] ParseError),
|
||||||
|
#[error("request token error: {:?}", 0)]
|
||||||
|
RequestToken(#[from] RequestTokenError),
|
||||||
|
#[error("claims verification error: {:?}", 0)]
|
||||||
|
ClaimsVerification(#[from] ClaimsVerificationError),
|
||||||
|
#[error("signing error: {:?}", 0)]
|
||||||
|
Signing(#[from] SigningError),
|
||||||
|
|
||||||
|
#[error("json serialization error: {:?}", 0)]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("csrf token is invalid")]
|
||||||
|
CsrfTokenInvalid,
|
||||||
|
|
||||||
|
#[error("id token not found")]
|
||||||
|
IdTokenNotFound,
|
||||||
|
|
||||||
|
#[error("access token hash is invalid")]
|
||||||
|
AccessTokenHashInvalid,
|
||||||
|
|
||||||
|
#[error("just a redirect")]
|
||||||
|
Redirect((PrivateCookieJar, Redirect)),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for Error {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
Self::CsrfTokenInvalid => {
|
||||||
|
{ (StatusCode::BAD_REQUEST, "csrf token is invalid").into_response() }
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Self::Redirect(redirect) => redirect.into_response(),
|
||||||
|
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
221
src/lib.rs
Normal file
221
src/lib.rs
Normal file
|
@ -0,0 +1,221 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
extract::{FromRef, FromRequestParts, Query},
|
||||||
|
http::request::Parts,
|
||||||
|
response::Redirect,
|
||||||
|
};
|
||||||
|
use axum_extra::extract::{
|
||||||
|
cookie::{Cookie, SameSite},
|
||||||
|
PrivateCookieJar,
|
||||||
|
};
|
||||||
|
use cookie::time::{Duration, OffsetDateTime};
|
||||||
|
use error::Error;
|
||||||
|
use openidconnect::{
|
||||||
|
core::{
|
||||||
|
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreErrorResponseType,
|
||||||
|
CoreGenderClaim, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse,
|
||||||
|
CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
|
||||||
|
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
|
||||||
|
CoreTokenType,
|
||||||
|
},
|
||||||
|
reqwest::async_http_client,
|
||||||
|
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
|
||||||
|
EmptyExtraTokenFields, IdTokenClaims, IdTokenFields, IssuerUrl, Nonce, OAuth2TokenResponse,
|
||||||
|
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardErrorResponse,
|
||||||
|
StandardTokenResponse, TokenResponse,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
pub use cookie::Key;
|
||||||
|
|
||||||
|
pub mod error;
|
||||||
|
|
||||||
|
const LOGIN_COOKIE_NAME: &str = "OIDC_LOGIN";
|
||||||
|
|
||||||
|
pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {}
|
||||||
|
|
||||||
|
type OidcTokenResponse<AC> = StandardTokenResponse<
|
||||||
|
IdTokenFields<
|
||||||
|
AC,
|
||||||
|
EmptyExtraTokenFields,
|
||||||
|
CoreGenderClaim,
|
||||||
|
CoreJweContentEncryptionAlgorithm,
|
||||||
|
CoreJwsSigningAlgorithm,
|
||||||
|
CoreJsonWebKeyType,
|
||||||
|
>,
|
||||||
|
CoreTokenType,
|
||||||
|
>;
|
||||||
|
|
||||||
|
pub type OidcClient<AC> = Client<
|
||||||
|
AC,
|
||||||
|
CoreAuthDisplay,
|
||||||
|
CoreGenderClaim,
|
||||||
|
CoreJweContentEncryptionAlgorithm,
|
||||||
|
CoreJwsSigningAlgorithm,
|
||||||
|
CoreJsonWebKeyType,
|
||||||
|
CoreJsonWebKeyUse,
|
||||||
|
CoreJsonWebKey,
|
||||||
|
CoreAuthPrompt,
|
||||||
|
StandardErrorResponse<CoreErrorResponseType>,
|
||||||
|
OidcTokenResponse<AC>,
|
||||||
|
CoreTokenType,
|
||||||
|
CoreTokenIntrospectionResponse,
|
||||||
|
CoreRevocableToken,
|
||||||
|
CoreRevocationErrorResponse,
|
||||||
|
>;
|
||||||
|
|
||||||
|
pub struct OidcApplication {
|
||||||
|
application_base: String,
|
||||||
|
issuer: IssuerUrl,
|
||||||
|
client_id: ClientId,
|
||||||
|
client_secret: Option<ClientSecret>,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
cookie_key: Key,
|
||||||
|
}
|
||||||
|
impl OidcApplication {
|
||||||
|
pub fn new(
|
||||||
|
application_base: String,
|
||||||
|
issuer: String,
|
||||||
|
client_id: String,
|
||||||
|
client_secret: Option<String>,
|
||||||
|
scopes: Vec<String>,
|
||||||
|
cookie_key: Key,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
application_base,
|
||||||
|
issuer: IssuerUrl::new(issuer).unwrap(),
|
||||||
|
client_id: ClientId::new(client_id),
|
||||||
|
client_secret: client_secret.map(ClientSecret::new),
|
||||||
|
scopes,
|
||||||
|
cookie_key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn create_client<AC: AdditionalClaims>(
|
||||||
|
&self,
|
||||||
|
redirect: String,
|
||||||
|
) -> Result<OidcClient<AC>, Error> {
|
||||||
|
let provider_metadata =
|
||||||
|
CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?;
|
||||||
|
let client = OidcClient::<AC>::from_provider_metadata(
|
||||||
|
provider_metadata,
|
||||||
|
self.client_id.clone(),
|
||||||
|
self.client_secret.clone(),
|
||||||
|
)
|
||||||
|
.set_redirect_uri(RedirectUrl::new(redirect)?);
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct EmptyAdditionalClaims {}
|
||||||
|
impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
|
||||||
|
impl AdditionalClaims for EmptyAdditionalClaims {}
|
||||||
|
|
||||||
|
pub struct ClaimsExtractor<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S, AC> FromRequestParts<S> for ClaimsExtractor<AC>
|
||||||
|
where
|
||||||
|
S: Send + Sync,
|
||||||
|
AC: AdditionalClaims,
|
||||||
|
OidcApplication: FromRef<S>,
|
||||||
|
{
|
||||||
|
type Rejection = Error;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let application: OidcApplication = OidcApplication::from_ref(state);
|
||||||
|
let client = application
|
||||||
|
.create_client(format!(
|
||||||
|
"{}/{}",
|
||||||
|
application.application_base,
|
||||||
|
parts.uri.path()
|
||||||
|
))
|
||||||
|
.await?;
|
||||||
|
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
|
||||||
|
|
||||||
|
let login_session = jar.get(LOGIN_COOKIE_NAME);
|
||||||
|
let query = Query::<OidcQuery>::from_request_parts(parts, state)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
|
||||||
|
let login_session: LoginSession = serde_json::from_str(login_session.value())?;
|
||||||
|
|
||||||
|
if login_session.csrf_token.secret() != &query.state {
|
||||||
|
return Err(Error::CsrfTokenInvalid);
|
||||||
|
}
|
||||||
|
|
||||||
|
let token_response = client
|
||||||
|
.exchange_code(AuthorizationCode::new(query.code.to_string()))
|
||||||
|
// Set the PKCE code verifier.
|
||||||
|
.set_pkce_verifier(login_session.pkce_verifier)
|
||||||
|
.request_async(async_http_client)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||||
|
let id_token = token_response.id_token().ok_or(Error::IdTokenNotFound)?;
|
||||||
|
let claims = id_token.claims(&client.id_token_verifier(), &login_session.nonce)?;
|
||||||
|
|
||||||
|
// Verify the access token hash to ensure that the access token hasn't been substituted for
|
||||||
|
// another user's.
|
||||||
|
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||||
|
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||||
|
token_response.access_token(),
|
||||||
|
&id_token.signing_alg()?,
|
||||||
|
)?;
|
||||||
|
if actual_access_token_hash != *expected_access_token_hash {
|
||||||
|
return Err(Error::AccessTokenHashInvalid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self(claims.clone()))
|
||||||
|
} else {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
let (auth_url, csrf_token, nonce) = {
|
||||||
|
let mut auth = client.authorize_url(
|
||||||
|
CoreAuthenticationFlow::AuthorizationCode,
|
||||||
|
CsrfToken::new_random,
|
||||||
|
Nonce::new_random,
|
||||||
|
);
|
||||||
|
|
||||||
|
for scope in application.scopes.iter() {
|
||||||
|
auth = auth.add_scope(Scope::new(scope.to_string()));
|
||||||
|
}
|
||||||
|
auth.set_pkce_challenge(pkce_challenge).url()
|
||||||
|
};
|
||||||
|
|
||||||
|
let login_session = LoginSession {
|
||||||
|
nonce,
|
||||||
|
csrf_token,
|
||||||
|
pkce_verifier,
|
||||||
|
};
|
||||||
|
let login_session = serde_json::to_string(&login_session)?;
|
||||||
|
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
|
||||||
|
cookie.set_same_site(SameSite::Lax);
|
||||||
|
cookie.set_secure(true);
|
||||||
|
cookie.set_http_only(true);
|
||||||
|
cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1));
|
||||||
|
jar = jar.add(cookie);
|
||||||
|
|
||||||
|
Err(Error::Redirect((
|
||||||
|
jar,
|
||||||
|
Redirect::temporary(auth_url.as_str()),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OidcQuery {
|
||||||
|
code: String,
|
||||||
|
state: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
session_state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct LoginSession {
|
||||||
|
nonce: Nonce,
|
||||||
|
csrf_token: CsrfToken,
|
||||||
|
pkce_verifier: PkceCodeVerifier,
|
||||||
|
}
|
Reference in a new issue