133 lines
4.2 KiB
Rust
133 lines
4.2 KiB
Rust
use std::{
|
|
borrow::Borrow,
|
|
pin::Pin,
|
|
task::{Context, Poll},
|
|
};
|
|
|
|
use bytes::{Bytes, BytesMut};
|
|
use chacha20::{
|
|
cipher::{KeyIvInit, StreamCipher},
|
|
ChaCha20,
|
|
};
|
|
use futures_util::Stream;
|
|
use log::{debug, warn};
|
|
use pin_project_lite::pin_project;
|
|
use sha3::{Digest, Sha3_256};
|
|
use tokio::io::AsyncRead;
|
|
use tokio_util::io::poll_read_buf;
|
|
|
|
use crate::{
|
|
metadata::Metadata,
|
|
util::{Id, Key, Nonce},
|
|
};
|
|
pin_project! {
    /// Stream state for reading encrypted bytes from `reader`, decrypting them
    /// with ChaCha20 and hashing the decrypted output with SHA3-256.
    pub(crate) struct DecryptingStream<R> {
        // Source of the encrypted bytes; `None` once the stream has terminated.
        #[pin]
        reader: Option<R>,
        // Buffer each read lands in before it is split off as a chunk.
        buf: BytesMut,
        // Number of bytes reserved whenever `buf` has no capacity left (chunk size).
        capacity: usize,
        // ChaCha20 keystream applied to every chunk.
        cipher: ChaCha20,
        // Running SHA3-256 hash over the decrypted chunks, used for the integrity check.
        hasher: Sha3_256,
        // Hex-encoded hash the decrypted data is compared against.
        target_hash: String,
        // Id of the file, only used in log messages.
        id: Id,
        // Total file size in bytes.
        size: u64,
        // Number of bytes read so far (position of the "reading head").
        progress: u64
    }
}
|
|
impl<R: AsyncRead> DecryptingStream<R> {
|
|
pub(crate) fn new(reader: R, id: Id, metadata: &Metadata, key: &Key, nonce: &Nonce) -> Self {
|
|
let cipher = ChaCha20::new(key.borrow(), nonce.borrow());
|
|
Self {
|
|
reader: Some(reader),
|
|
buf: BytesMut::new(),
|
|
capacity: 1 << 22, // 4 MiB
|
|
cipher,
|
|
hasher: Sha3_256::new(),
|
|
target_hash: metadata.etag.clone().unwrap_or_default(),
|
|
id,
|
|
size: metadata.size.unwrap_or_default(),
|
|
progress: 0,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: AsyncRead> Stream for DecryptingStream<R> {
|
|
type Item = std::io::Result<Bytes>;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
let mut this = self.as_mut().project();
|
|
|
|
let reader = match this.reader.as_pin_mut() {
|
|
Some(r) => r,
|
|
None => return Poll::Ready(None),
|
|
};
|
|
|
|
if this.buf.capacity() == 0 {
|
|
this.buf.reserve(*this.capacity);
|
|
}
|
|
|
|
match poll_read_buf(reader, cx, &mut this.buf) {
|
|
Poll::Pending => Poll::Pending,
|
|
Poll::Ready(Err(err)) => {
|
|
debug!("failed to send bin {}", this.id);
|
|
self.project().reader.set(None);
|
|
Poll::Ready(Some(Err(err)))
|
|
}
|
|
Poll::Ready(Ok(0)) => {
|
|
if self.progress_check() == DecryptingStreamProgress::Failed {
|
|
// The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail
|
|
self.project().reader.set(None);
|
|
return Poll::Ready(None);
|
|
};
|
|
self.project().reader.set(None);
|
|
Poll::Ready(None)
|
|
}
|
|
Poll::Ready(Ok(n)) => {
|
|
let mut chunk = this.buf.split();
|
|
// decrypt the chunk using chacha
|
|
this.cipher.apply_keystream(&mut chunk);
|
|
// update the sha3 hasher
|
|
this.hasher.update(&chunk);
|
|
// track progress
|
|
*this.progress += n as u64;
|
|
if self.progress_check() == DecryptingStreamProgress::Failed {
|
|
// The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail
|
|
warn!("bin {} is corrupted! transmission failed", self.id);
|
|
self.project().reader.set(None);
|
|
return Poll::Ready(None);
|
|
};
|
|
Poll::Ready(Some(Ok(chunk.freeze())))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<R: AsyncRead> DecryptingStream<R> {
|
|
/// checks if the hash is correct when the last byte has been read
|
|
fn progress_check(&self) -> DecryptingStreamProgress {
|
|
if self.progress >= self.size {
|
|
let hash = hex::encode(self.hasher.clone().finalize());
|
|
if hash != self.target_hash {
|
|
DecryptingStreamProgress::Failed
|
|
} else {
|
|
DecryptingStreamProgress::Finished
|
|
}
|
|
} else {
|
|
DecryptingStreamProgress::Running
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Verification state of a decrypting stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecryptingStreamProgress {
    /// All expected bytes were read and the hash matches the target hash.
    Finished,
    /// All expected bytes were read but the hash does not match the target hash.
    Failed,
    /// Fewer bytes than the expected size have been read so far.
    Running,
}
|