use std::{ borrow::Borrow, pin::Pin, task::{Context, Poll}, }; use bytes::{Bytes, BytesMut}; use chacha20::{ cipher::{KeyIvInit, StreamCipher}, ChaCha20, }; use futures_util::Stream; use log::{debug, warn}; use pin_project_lite::pin_project; use sha3::{Digest, Sha3_256}; use tokio::io::AsyncRead; use tokio_util::io::poll_read_buf; use crate::{ metadata::Metadata, util::{Id, Key, Nonce}, }; pin_project! { pub(crate) struct DecryptingStream { #[pin] reader: Option, buf: BytesMut, // chunk size capacity: usize, // chacha20 cipher cipher: ChaCha20, // hasher to verify file integrity hasher: Sha3_256, // hash to verify against target_hash: String, // id of the file for logging purposes id: Id, // total file size size: u64, // current position of the "reading head" progress: u64 } } impl DecryptingStream { pub(crate) fn new(reader: R, id: Id, metadata: &Metadata, key: &Key, nonce: &Nonce) -> Self { let cipher = ChaCha20::new(key.borrow(), nonce.borrow()); Self { reader: Some(reader), buf: BytesMut::new(), capacity: 1 << 22, // 4 MiB cipher, hasher: Sha3_256::new(), target_hash: metadata.etag.clone().unwrap_or_default(), id, size: metadata.size.unwrap_or_default(), progress: 0, } } } impl Stream for DecryptingStream { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut this = self.as_mut().project(); let reader = match this.reader.as_pin_mut() { Some(r) => r, None => return Poll::Ready(None), }; if this.buf.capacity() == 0 { this.buf.reserve(*this.capacity); } match poll_read_buf(reader, cx, &mut this.buf) { Poll::Pending => Poll::Pending, Poll::Ready(Err(err)) => { debug!("failed to send bin {}", this.id); self.project().reader.set(None); Poll::Ready(Some(Err(err))) } Poll::Ready(Ok(0)) => { if self.progress_check() == DecryptingStreamProgress::Failed { // The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail self.project().reader.set(None); return Poll::Ready(None); }; self.project().reader.set(None); Poll::Ready(None) } Poll::Ready(Ok(n)) => { let mut chunk = this.buf.split(); // decrypt the chunk using chacha this.cipher.apply_keystream(&mut chunk); // update the sha3 hasher this.hasher.update(&chunk); // track progress *this.progress += n as u64; if self.progress_check() == DecryptingStreamProgress::Failed { // The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail warn!("bin {} is corrupted! transmission failed", self.id); self.project().reader.set(None); return Poll::Ready(None); }; Poll::Ready(Some(Ok(chunk.freeze()))) } } } } impl DecryptingStream { /// checks if the hash is correct when the last byte has been read fn progress_check(&self) -> DecryptingStreamProgress { if self.progress >= self.size { let hash = hex::encode(self.hasher.clone().finalize()); if hash != self.target_hash { DecryptingStreamProgress::Failed } else { DecryptingStreamProgress::Finished } } else { DecryptingStreamProgress::Running } } } #[derive(PartialEq, Eq)] enum DecryptingStreamProgress { Finished, Failed, Running, }