init
This commit is contained in:
commit
75ed3b861a
4 changed files with 294 additions and 0 deletions
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
/target
|
||||
/Cargo.lock
|
17
Cargo.toml
Normal file
17
Cargo.toml
Normal file
|
@ -0,0 +1,17 @@
|
|||
[package]
|
||||
name = "axum_oidc"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6"
|
||||
axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
|
||||
cookie = "0.17"
|
||||
openidconnect = "3.0"
|
||||
async-trait = "0.1"
|
||||
serde = "1.0"
|
||||
thiserror = "1.0"
|
||||
reqwest = { version="0.11", default_features=false}
|
||||
serde_json = "1.0"
|
54
src/error.rs
Normal file
54
src/error.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
use axum::response::{IntoResponse, Redirect};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use openidconnect::{
|
||||
core::CoreErrorResponseType, url::ParseError, ClaimsVerificationError, DiscoveryError,
|
||||
SigningError, StandardErrorResponse,
|
||||
};
|
||||
use reqwest::StatusCode;
|
||||
|
||||
type RequestTokenError = openidconnect::RequestTokenError<
|
||||
openidconnect::reqwest::Error<reqwest::Error>,
|
||||
StandardErrorResponse<CoreErrorResponseType>,
|
||||
>;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("discovery error: {:?}", 0)]
|
||||
Discovery(#[from] DiscoveryError<openidconnect::reqwest::Error<reqwest::Error>>),
|
||||
#[error("parse error: {:?}", 0)]
|
||||
Parse(#[from] ParseError),
|
||||
#[error("request token error: {:?}", 0)]
|
||||
RequestToken(#[from] RequestTokenError),
|
||||
#[error("claims verification error: {:?}", 0)]
|
||||
ClaimsVerification(#[from] ClaimsVerificationError),
|
||||
#[error("signing error: {:?}", 0)]
|
||||
Signing(#[from] SigningError),
|
||||
|
||||
#[error("json serialization error: {:?}", 0)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("csrf token is invalid")]
|
||||
CsrfTokenInvalid,
|
||||
|
||||
#[error("id token not found")]
|
||||
IdTokenNotFound,
|
||||
|
||||
#[error("access token hash is invalid")]
|
||||
AccessTokenHashInvalid,
|
||||
|
||||
#[error("just a redirect")]
|
||||
Redirect((PrivateCookieJar, Redirect)),
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::CsrfTokenInvalid => {
|
||||
{ (StatusCode::BAD_REQUEST, "csrf token is invalid").into_response() }
|
||||
.into_response()
|
||||
}
|
||||
Self::Redirect(redirect) => redirect.into_response(),
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
|
||||
}
|
||||
}
|
||||
}
|
221
src/lib.rs
Normal file
221
src/lib.rs
Normal file
|
@ -0,0 +1,221 @@
|
|||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{FromRef, FromRequestParts, Query},
|
||||
http::request::Parts,
|
||||
response::Redirect,
|
||||
};
|
||||
use axum_extra::extract::{
|
||||
cookie::{Cookie, SameSite},
|
||||
PrivateCookieJar,
|
||||
};
|
||||
use cookie::time::{Duration, OffsetDateTime};
|
||||
use error::Error;
|
||||
use openidconnect::{
|
||||
core::{
|
||||
CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreErrorResponseType,
|
||||
CoreGenderClaim, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse,
|
||||
CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
|
||||
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
|
||||
CoreTokenType,
|
||||
},
|
||||
reqwest::async_http_client,
|
||||
AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
|
||||
EmptyExtraTokenFields, IdTokenClaims, IdTokenFields, IssuerUrl, Nonce, OAuth2TokenResponse,
|
||||
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardErrorResponse,
|
||||
StandardTokenResponse, TokenResponse,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use cookie::Key;
|
||||
|
||||
pub mod error;
|
||||
|
||||
const LOGIN_COOKIE_NAME: &str = "OIDC_LOGIN";
|
||||
|
||||
pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {}
|
||||
|
||||
type OidcTokenResponse<AC> = StandardTokenResponse<
|
||||
IdTokenFields<
|
||||
AC,
|
||||
EmptyExtraTokenFields,
|
||||
CoreGenderClaim,
|
||||
CoreJweContentEncryptionAlgorithm,
|
||||
CoreJwsSigningAlgorithm,
|
||||
CoreJsonWebKeyType,
|
||||
>,
|
||||
CoreTokenType,
|
||||
>;
|
||||
|
||||
pub type OidcClient<AC> = Client<
|
||||
AC,
|
||||
CoreAuthDisplay,
|
||||
CoreGenderClaim,
|
||||
CoreJweContentEncryptionAlgorithm,
|
||||
CoreJwsSigningAlgorithm,
|
||||
CoreJsonWebKeyType,
|
||||
CoreJsonWebKeyUse,
|
||||
CoreJsonWebKey,
|
||||
CoreAuthPrompt,
|
||||
StandardErrorResponse<CoreErrorResponseType>,
|
||||
OidcTokenResponse<AC>,
|
||||
CoreTokenType,
|
||||
CoreTokenIntrospectionResponse,
|
||||
CoreRevocableToken,
|
||||
CoreRevocationErrorResponse,
|
||||
>;
|
||||
|
||||
pub struct OidcApplication {
|
||||
application_base: String,
|
||||
issuer: IssuerUrl,
|
||||
client_id: ClientId,
|
||||
client_secret: Option<ClientSecret>,
|
||||
scopes: Vec<String>,
|
||||
cookie_key: Key,
|
||||
}
|
||||
impl OidcApplication {
|
||||
pub fn new(
|
||||
application_base: String,
|
||||
issuer: String,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
scopes: Vec<String>,
|
||||
cookie_key: Key,
|
||||
) -> Self {
|
||||
Self {
|
||||
application_base,
|
||||
issuer: IssuerUrl::new(issuer).unwrap(),
|
||||
client_id: ClientId::new(client_id),
|
||||
client_secret: client_secret.map(ClientSecret::new),
|
||||
scopes,
|
||||
cookie_key,
|
||||
}
|
||||
}
|
||||
async fn create_client<AC: AdditionalClaims>(
|
||||
&self,
|
||||
redirect: String,
|
||||
) -> Result<OidcClient<AC>, Error> {
|
||||
let provider_metadata =
|
||||
CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?;
|
||||
let client = OidcClient::<AC>::from_provider_metadata(
|
||||
provider_metadata,
|
||||
self.client_id.clone(),
|
||||
self.client_secret.clone(),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect)?);
|
||||
Ok(client)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmptyAdditionalClaims {}
|
||||
impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {}
|
||||
impl AdditionalClaims for EmptyAdditionalClaims {}
|
||||
|
||||
pub struct ClaimsExtractor<AC: AdditionalClaims>(pub IdTokenClaims<AC, CoreGenderClaim>);
|
||||
|
||||
#[async_trait]
|
||||
impl<S, AC> FromRequestParts<S> for ClaimsExtractor<AC>
|
||||
where
|
||||
S: Send + Sync,
|
||||
AC: AdditionalClaims,
|
||||
OidcApplication: FromRef<S>,
|
||||
{
|
||||
type Rejection = Error;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let application: OidcApplication = OidcApplication::from_ref(state);
|
||||
let client = application
|
||||
.create_client(format!(
|
||||
"{}/{}",
|
||||
application.application_base,
|
||||
parts.uri.path()
|
||||
))
|
||||
.await?;
|
||||
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
|
||||
|
||||
let login_session = jar.get(LOGIN_COOKIE_NAME);
|
||||
let query = Query::<OidcQuery>::from_request_parts(parts, state)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
|
||||
let login_session: LoginSession = serde_json::from_str(login_session.value())?;
|
||||
|
||||
if login_session.csrf_token.secret() != &query.state {
|
||||
return Err(Error::CsrfTokenInvalid);
|
||||
}
|
||||
|
||||
let token_response = client
|
||||
.exchange_code(AuthorizationCode::new(query.code.to_string()))
|
||||
// Set the PKCE code verifier.
|
||||
.set_pkce_verifier(login_session.pkce_verifier)
|
||||
.request_async(async_http_client)
|
||||
.await?;
|
||||
|
||||
// Extract the ID token claims after verifying its authenticity and nonce.
|
||||
let id_token = token_response.id_token().ok_or(Error::IdTokenNotFound)?;
|
||||
let claims = id_token.claims(&client.id_token_verifier(), &login_session.nonce)?;
|
||||
|
||||
// Verify the access token hash to ensure that the access token hasn't been substituted for
|
||||
// another user's.
|
||||
if let Some(expected_access_token_hash) = claims.access_token_hash() {
|
||||
let actual_access_token_hash = AccessTokenHash::from_token(
|
||||
token_response.access_token(),
|
||||
&id_token.signing_alg()?,
|
||||
)?;
|
||||
if actual_access_token_hash != *expected_access_token_hash {
|
||||
return Err(Error::AccessTokenHashInvalid);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self(claims.clone()))
|
||||
} else {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (auth_url, csrf_token, nonce) = {
|
||||
let mut auth = client.authorize_url(
|
||||
CoreAuthenticationFlow::AuthorizationCode,
|
||||
CsrfToken::new_random,
|
||||
Nonce::new_random,
|
||||
);
|
||||
|
||||
for scope in application.scopes.iter() {
|
||||
auth = auth.add_scope(Scope::new(scope.to_string()));
|
||||
}
|
||||
auth.set_pkce_challenge(pkce_challenge).url()
|
||||
};
|
||||
|
||||
let login_session = LoginSession {
|
||||
nonce,
|
||||
csrf_token,
|
||||
pkce_verifier,
|
||||
};
|
||||
let login_session = serde_json::to_string(&login_session)?;
|
||||
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
|
||||
cookie.set_same_site(SameSite::Lax);
|
||||
cookie.set_secure(true);
|
||||
cookie.set_http_only(true);
|
||||
cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1));
|
||||
jar = jar.add(cookie);
|
||||
|
||||
Err(Error::Redirect((
|
||||
jar,
|
||||
Redirect::temporary(auth_url.as_str()),
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OidcQuery {
|
||||
code: String,
|
||||
state: String,
|
||||
#[allow(dead_code)]
|
||||
session_state: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct LoginSession {
|
||||
nonce: Nonce,
|
||||
csrf_token: CsrfToken,
|
||||
pkce_verifier: PkceCodeVerifier,
|
||||
}
|
Reference in a new issue