diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 4cd7e0a..e7a9196 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -10,5 +10,10 @@ clap = { version="4.4", features = ["derive"] } reqwest = { version="0.11", features = ["rustls-tls", "stream"], default-features=false} openidconnect = "3.4" thiserror = "1.0" -confy = "0.5" serde = { version="1.0", features = [ "derive" ] } +axum = "0.6" +tokio = { version = "1.33", features = ["full"] } +open = "5.0" +tokio-util = { version="0.7.9", features = ["io"]} +dirs = "5.0" +confy = "0.5" diff --git a/cli/src/auth.rs b/cli/src/auth.rs new file mode 100644 index 0000000..33ea1af --- /dev/null +++ b/cli/src/auth.rs @@ -0,0 +1,174 @@ +use std::sync::Arc; + +use axum::{ + extract::{Query, State}, + response::{Html, IntoResponse}, + routing::get, + Router, +}; +use openidconnect::{ + core::{CoreAuthenticationFlow, CoreClient, CoreErrorResponseType, CoreProviderMetadata}, + reqwest::async_http_client, + AccessTokenHash, AuthorizationCode, ClaimsVerificationError, ClientId, CsrfToken, + DiscoveryError, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, RedirectUrl, + RefreshToken, RequestTokenError, Scope, SigningError, StandardErrorResponse, TokenResponse, +}; +use serde::Deserialize; +use thiserror::Error; +use tokio::sync::mpsc; + +#[derive(Error, Debug)] +pub enum Error { + #[error("url parse error: {:?}", 0)] + UrlParse(#[from] openidconnect::url::ParseError), + + #[error("discovery error: {:?}", 0)] + Discovery(#[from] DiscoveryError>), + + #[error("request token error: {:?}", 0)] + RequestToken( + #[from] + RequestTokenError< + openidconnect::reqwest::Error, + StandardErrorResponse, + >, + ), + + #[error("claims verification error: {:?}", 0)] + ClaimsVerification(#[from] ClaimsVerificationError), + + #[error("signing error: {:?}", 0)] + Signing(#[from] SigningError), + + #[error("server did not return an id token")] + NoIdToken, + + #[error("invalid access token")] + InvalidAccessToken, + + #[error("no response received")] + NoResponse, + + #[error("csrf mismatch")] + CsrfMismatch, +} + +#[derive(Debug, Deserialize)] +struct ResponseData { + pub code: String, + pub state: String, +} + +pub(crate) async fn login( + issuer: &str, + client_id: &str, + scopes: &[String], + refresh_token: &mut Option, +) -> Result { + let provider_metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(issuer.to_string())?, + async_http_client, + ) + .await?; + + // Create an OpenID Connect client by specifying the client ID, client secret, authorization URL + // and token URL. + let client = CoreClient::from_provider_metadata( + provider_metadata, + ClientId::new(client_id.to_string()), + None, + ) + // Set the URL the user will be redirected to after the authorization process. + .set_redirect_uri(RedirectUrl::new("http://[::1]:8080".to_string())?); + + // Generate a PKCE challenge. + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + if let Some(refresh_token) = refresh_token { + if let Ok(token_response) = client + .exchange_refresh_token(&RefreshToken::new(refresh_token.to_string())) + .request_async(async_http_client) + .await + { + eprintln!("authenticated with oidc provider"); + return Ok(token_response.access_token().secret().clone()); + } + } + + // Generate the full authorization URL. + let mut auth = client.authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); + + for scope in scopes { + auth = auth.add_scope(Scope::new(scope.to_string())); + } + let (auth_url, csrf_token, nonce) = auth + // Set the PKCE code challenge. + .set_pkce_challenge(pkce_challenge) + .url(); + open::that(auth_url.to_string()).unwrap(); + eprintln!("a browser should have been opened with the url {auth_url}. please login with your oidc provider."); + + let (fuse_tx, mut fuse_rx) = mpsc::channel::(1); + let app = Router::new() + .route("/", get(handle_post)) + .with_state(Arc::new(fuse_tx)); + + let server = axum::Server::bind(&"[::1]:8080".parse().unwrap()).serve(app.into_make_service()); + + let data = tokio::select! { + x = fuse_rx.recv() => { + x + } + _ = server => { + None + } + }; + + let data = data.ok_or(Error::NoResponse)?; + + // match csrf_state + + if *csrf_token.secret() != data.state { + return Err(Error::CsrfMismatch); + } + + let token_response = client + .exchange_code(AuthorizationCode::new(data.code)) + // Set the PKCE code verifier. + .set_pkce_verifier(pkce_verifier) + .request_async(async_http_client) + .await?; + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response.id_token().ok_or_else(|| Error::NoIdToken)?; + let claims = id_token.claims(&client.id_token_verifier(), &nonce)?; + + // Verify the access token hash to ensure that the access token hasn't been substituted for + // another user's. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(token_response.access_token(), &id_token.signing_alg()?)?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(Error::InvalidAccessToken); + } + } + + if let Some(new_refresh_token) = token_response.refresh_token() { + *refresh_token = Some(new_refresh_token.secret().to_string()); + } + + eprintln!("authenticated with oidc provider"); + Ok(token_response.access_token().secret().clone()) +} + +async fn handle_post( + State(fuse_tx): State>>, + Query(data): Query, +) -> impl IntoResponse { + fuse_tx.clone().send(data).await; + Html("Die Anmeldung war erfolgreich. Du kannst dieses Fenster jetzt schließen.") +} diff --git a/cli/src/main.rs b/cli/src/main.rs index 1fa186f..529294c 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -1,57 +1,94 @@ -use std::path::PathBuf; - -use clap::{Parser, Subcommand}; +use clap::Parser; +use reqwest::{Body, Url}; use serde::{Deserialize, Serialize}; +use tokio::io::stdin; +use tokio_util::io::ReaderStream; -#[derive(Debug, Serialize, Deserialize)] +use crate::auth::login; + +mod auth; + +#[derive(Serialize, Deserialize, Debug)] pub struct Config { - server: String, - client_id: String, - client_secret: String, - claims: Vec, - challenge_port: u32, + pub refresh_token: Option, + pub binurl: String, + pub issuer: String, + pub client_id: String, + pub scopes: Vec, } #[derive(Debug, Parser)] pub struct Args { - #[arg(short, long, value_name = "FILE")] - config: Option, + #[arg(short, long)] + content_type: Option, - #[command(subcommand)] - command: Option, + #[arg(short, long)] + ttl: Option, } -#[derive(Debug, Subcommand)] -pub enum Command { - Create { - #[arg(short, long, action)] - stdin: bool, - }, - Upload {}, - Login { - /// challenge port to listen to - #[arg(short, long, value_name = "PORT")] - port: Option, - - /// OIDC server - #[arg(long, value_name = "URL")] - server: Option, - - /// OIDC client id - #[arg(long)] - client: Option, - - /// OIDC client secret - #[arg(long)] - secret: Option, - - /// OIDC claims - #[arg(long)] - claims: Option>, - }, +impl Default for Config { + fn default() -> Self { + Self { + refresh_token: None, + binurl: "https://bin.zettoit.eu".to_string(), + issuer: "https://auth.zettoit.eu/realms/zettoit".to_string(), + client_id: "binctl".to_string(), + scopes: vec!["zettoit-bin".to_string()], + } + } } -fn main() { +#[tokio::main] +async fn main() { + let mut cfg: Config = confy::load("binctl", None).unwrap_or_default(); + let args = Args::parse(); - dbg!(args); + let access_token = login( + &cfg.issuer, + &cfg.client_id, + cfg.scopes.as_slice(), + &mut cfg.refresh_token, + ) + .await + .unwrap(); + let mut bin = create_bin(&cfg.binurl, &access_token).await.unwrap(); + eprintln!("created bin at {}. uploading...", bin); + bin.set_query(args.ttl.map(|x| format!("ttl={}", x)).as_deref()); + + upload_to_bin( + bin.as_ref(), + &args + .content_type + .unwrap_or("application/octet-stream".to_string()), + ) + .await + .unwrap(); + + let _ = confy::store("binctl", None, cfg); + bin.set_query(None); + print!("{bin}"); +} + +async fn create_bin(binurl: &str, access_token: &str) -> Result { + let client = reqwest::Client::new(); + + Ok(client + .get(binurl) + .header("Authorization", format!("Bearer {}", access_token)) + .send() + .await? + .url() + .to_owned()) +} + +async fn upload_to_bin(url: &str, content_type: &str) -> Result<(), reqwest::Error> { + let client = reqwest::Client::new(); + + client + .post(url) + .header("Content-Type", content_type) + .body(Body::wrap_stream(ReaderStream::new(stdin()))) + .send() + .await?; + Ok(()) } diff --git a/flake.nix b/flake.nix index 4997bb7..c66048e 100644 --- a/flake.nix +++ b/flake.nix @@ -26,20 +26,22 @@ nixpkgs.lib.genAttrs [ "x86_64-linux" "aarch64-linux" - ] (system: let + ] (system: function system nixpkgs.legacyPackages.${system}); + in rec { + packages = forAllSystems(system: syspkgs: let pkgs = import nixpkgs { inherit system; overlays = [ (import rust-overlay) ]; }; rustToolchain = pkgs.rust-bin.stable.latest.default; - markdownFilter = path: _type: builtins.match ".*md$" path != null; - markdownOrCargo = path: type: (markdownFilter path type) || (craneLib.filterCargoSources path type); - craneLib = (crane.mkLib pkgs).overrideToolchain rustToolchain; src = pkgs.lib.cleanSourceWith { src = craneLib.path ./.; - filter = markdownOrCargo; + filter = path: type: + (pkgs.lib.hasSuffix "\.md" path) || + (craneLib.filterCargoSources path type) + ; }; nativeBuildInputs = with pkgs; [ rustToolchain pkg-config ]; @@ -52,18 +54,20 @@ bin = craneLib.buildPackage (commonArgs // { inherit cargoArtifacts; + pname = "bin"; }); - in function { - inherit bin pkgs; - }); - in { - packages = forAllSystems({pkgs, bin}: { - inherit bin; + binctl = craneLib.buildPackage (commonArgs // { + inherit cargoArtifacts; + pname = "binctl"; + }); + in { + inherit bin binctl; default = bin; }); - devShells = forAllSystems({pkgs, bin}: pkgs.mkShell { - inputsFrom = bin; + devShells = forAllSystems(system: pkgs: pkgs.mkShell { + inputsFrom = [packages.${system}.bin packages.${system}.binctl]; }); - hydraJobs."build" = forAllSystems({pkgs, bin}: bin); + hydraJobs."bin" = forAllSystems(system: pkgs: packages.${system}.bin); + hydraJobs."binctl" = forAllSystems(system: pkgs: packages.${system}.binctl); }; }