multipart support

This commit is contained in:
Paul Zinselmeyer 2023-10-20 11:26:38 +02:00
parent 5ac3a20b6c
commit d9b512f8d8
4 changed files with 108 additions and 14 deletions

29
Cargo.lock generated
View file

@ -113,6 +113,7 @@ dependencies = [
"matchit", "matchit",
"memchr", "memchr",
"mime", "mime",
"multer",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"rustversion", "rustversion",
@ -1087,7 +1088,7 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
dependencies = [ dependencies = [
"spin", "spin 0.5.2",
] ]
[[package]] [[package]]
@ -1173,6 +1174,24 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "multer"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http",
"httparse",
"log",
"memchr",
"mime",
"spin 0.9.8",
"version_check",
]
[[package]] [[package]]
name = "num-bigint" name = "num-bigint"
version = "0.4.4" version = "0.4.4"
@ -1681,7 +1700,7 @@ dependencies = [
"cc", "cc",
"libc", "libc",
"once_cell", "once_cell",
"spin", "spin 0.5.2",
"untrusted", "untrusted",
"web-sys", "web-sys",
"winapi", "winapi",
@ -2030,6 +2049,12 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.2" version = "0.7.2"

View file

@ -9,7 +9,7 @@ edition = "2021"
tokio = { version = "1.33", features = ["full"] } tokio = { version = "1.33", features = ["full"] }
tokio-util = { version="0.7", features = ["io"]} tokio-util = { version="0.7", features = ["io"]}
futures-util = "0.3" futures-util = "0.3"
axum = {version="0.6", features=["macros", "headers"]} axum = {version="0.6", features=["macros", "headers", "multipart"]}
serde = "1.0" serde = "1.0"
toml = "0.8" toml = "0.8"
render = { git="https://github.com/render-rs/render.rs" } render = { git="https://github.com/render-rs/render.rs" }

View file

@ -32,6 +32,9 @@ pub enum Error {
#[error("oidc redirect")] #[error("oidc redirect")]
Oidc(Response), Oidc(Response),
#[error("invalid multipart")]
InvalidMultipart,
} }
impl IntoResponse for Error { impl IntoResponse for Error {
@ -44,6 +47,9 @@ impl IntoResponse for Error {
} }
Self::ParseTtl => (StatusCode::BAD_REQUEST, "invalid ttl class\n").into_response(), Self::ParseTtl => (StatusCode::BAD_REQUEST, "invalid ttl class\n").into_response(),
Self::Oidc(response) => response.into_response(), Self::Oidc(response) => response.into_response(),
Self::InvalidMultipart => {
(StatusCode::BAD_REQUEST, "invalid multipart data").into_response()
}
_ => { _ => {
error!("{:?}", self); error!("{:?}", self);
(StatusCode::INTERNAL_SERVER_ERROR, "internal server error\n").into_response() (StatusCode::INTERNAL_SERVER_ERROR, "internal server error\n").into_response()

View file

@ -7,11 +7,15 @@ use std::{
}; };
use axum::{ use axum::{
body::StreamBody, async_trait,
body::{HttpBody, StreamBody},
debug_handler, debug_handler,
extract::{BodyStream, FromRef, Path, Query, State}, extract::{BodyStream, FromRef, FromRequest, Multipart, Path, Query, State},
headers::ContentType, headers::ContentType,
http::{header, HeaderMap, StatusCode}, http::{
header::{self, CONTENT_TYPE},
HeaderMap, Request, StatusCode,
},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
routing::get, routing::get,
Router, TypedHeader, Router, TypedHeader,
@ -20,6 +24,7 @@ use axum_oidc::{
jwt::{Claims, JwtApplication}, jwt::{Claims, JwtApplication},
oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor}, oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor},
}; };
use bytes::Bytes;
use chacha20::{ use chacha20::{
cipher::{KeyIvInit, StreamCipher}, cipher::{KeyIvInit, StreamCipher},
ChaCha20, ChaCha20,
@ -199,7 +204,7 @@ async fn post_item(
Query(params): Query<PostQuery>, Query(params): Query<PostQuery>,
State(app_state): State<AppState>, State(app_state): State<AppState>,
content_type: Option<TypedHeader<ContentType>>, content_type: Option<TypedHeader<ContentType>>,
mut stream: BodyStream, data: MultipartOrStream,
) -> HandlerResult<impl IntoResponse> { ) -> HandlerResult<impl IntoResponse> {
let phrase = Phrase::from_str(&phrase)?; let phrase = Phrase::from_str(&phrase)?;
let id = Id::from_phrase(&phrase, &app_state.id_salt); let id = Id::from_phrase(&phrase, &app_state.id_salt);
@ -221,6 +226,8 @@ async fn post_item(
let mut etag_hasher = Sha3_256::new(); let mut etag_hasher = Sha3_256::new();
let mut size = 0; let mut size = 0;
match data {
MultipartOrStream::Stream(mut stream) => {
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
let mut buf = chunk.unwrap_or_default().to_vec(); let mut buf = chunk.unwrap_or_default().to_vec();
etag_hasher.update(&buf); etag_hasher.update(&buf);
@ -228,7 +235,25 @@ async fn post_item(
cipher.apply_keystream(&mut buf); cipher.apply_keystream(&mut buf);
writer.write_all(&buf).await?; writer.write_all(&buf).await?;
} }
}
MultipartOrStream::Multipart(mut multipart) => {
while let Some(mut field) = multipart
.next_field()
.await
.map_err(|_| Error::InvalidMultipart)?
{
if field.name().unwrap_or_default() == "file" {
while let Some(chunk) = field.chunk().await.unwrap_or_default() {
let mut buf = chunk.to_vec();
etag_hasher.update(&buf);
size += buf.len() as u64;
cipher.apply_keystream(&mut buf);
writer.write_all(&buf).await?;
}
}
}
}
}
writer.flush().await?; writer.flush().await?;
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
@ -325,3 +350,41 @@ async fn get_item(
Ok((StatusCode::OK, headers, body).into_response()) Ok((StatusCode::OK, headers, body).into_response())
} }
} }
enum MultipartOrStream {
Multipart(Multipart),
Stream(BodyStream),
}
#[async_trait]
impl<S, B> FromRequest<S, B> for MultipartOrStream
where
B: Send + 'static + HttpBody,
S: Send + Sync,
Bytes: From<<B as HttpBody>::Data>,
<B as HttpBody>::Error:
Send + Sync + Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let is_multipart = req
.headers()
.get(CONTENT_TYPE)
.map(|x| x == "multipart/form-data")
.unwrap_or_default();
if is_multipart {
Ok(Self::Multipart(
Multipart::from_request(req, state)
.await
.map_err(|x| x.into_response())?,
))
} else {
Ok(Self::Stream(
BodyStream::from_request(req, state)
.await
.map_err(|x| x.into_response())?,
))
}
}
}