diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index 5ee93a7..b06bc7c 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -6,57 +6,87 @@ //! Key design principles: //! - Explicit state machines for all async operations //! - Never lose data on partial reads/writes -//! - Honest reporting of bytes written +//! - Honest reporting of bytes written (AsyncWrite contract) //! - Bounded internal buffers with backpressure +//! +//! AES-CTR is a stream cipher: the keystream position must advance exactly by the +//! number of plaintext bytes that are *accepted* (written or buffered). +//! +//! This implementation guarantees: +//! - CTR state never "drifts" +//! - never accept plaintext unless we can guarantee that all corresponding ciphertext +//! is either written to upstream or stored in our pending buffer +//! - when upstream is pending -> ciphertext is buffered/bounded and backpressure is applied +//! +//! ======================= +//! Writer state machine +//! ======================= +//! +//! ┌──────────┐ write buf ┌──────────┐ +//! │ Idle │ ---------------> │ Flushing │ +//! │ │ <--------------- │ │ +//! └──────────┘ drained └──────────┘ +//! │ │ +//! │ errors │ +//! ▼ ▼ +//! ┌────────────────────────────────────────┐ +//! │ Poisoned │ +//! └────────────────────────────────────────┘ +//! +//! Backpressure +//! - pending ciphertext buffer is bounded (MAX_PENDING_WRITE) +//! - pending is full and upstream is pending +//! -> poll_write returns Poll::Pending +//! -> do not accept any plaintext +//! +//! Performance +//! - fast path when pending is empty: encrypt into scratch and try upstream +//! - if upstream Pending/partial => move remainder into pending without re-encrypting +//! - when upstream is Pending but pending still has room: accept `to_accept` bytes and +//! encrypt+append ciphertext directly into pending (in-place encryption of appended range) -use bytes::{Bytes, BytesMut, BufMut}; -use std::io::{self, Error, ErrorKind, Result}; +use bytes::{Bytes, BytesMut}; +use std::io::{self, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::{debug, trace, warn}; use crate::crypto::AesCtr; -use crate::error::StreamError; -use super::state::{StreamState, ReadBuffer, WriteBuffer, YieldBuffer}; +use super::state::{StreamState, YieldBuffer}; // ============= Constants ============= -/// Maximum size for pending write buffer (256KB) -const MAX_PENDING_WRITE: usize = 256 * 1024; +/// Maximum size for pending ciphertext buffer (bounded backpressure). +/// 512 KiB tends to work well for mobile networks and avoids huge latency spikes. +const MAX_PENDING_WRITE: usize = 524_288; -/// Default read buffer capacity +/// Default read buffer capacity (reader mostly decrypts in-place into caller buffer). const DEFAULT_READ_CAPACITY: usize = 16 * 1024; // ============= CryptoReader State ============= -/// State machine states for CryptoReader #[derive(Debug)] enum CryptoReaderState { /// Ready to read new data Idle, - + /// Have decrypted data ready to yield to caller - Yielding { - /// Buffer containing decrypted data - buffer: YieldBuffer, - }, - + Yielding { buffer: YieldBuffer }, + /// Stream encountered an error and cannot be used - Poisoned { - /// The error that caused poisoning (taken on first access) - error: Option, - }, + Poisoned { error: Option }, } impl StreamState for CryptoReaderState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -68,7 +98,7 @@ impl StreamState for CryptoReaderState { // ============= CryptoReader ============= -/// Reader that decrypts data using AES-CTR with proper state machine +/// Reader that decrypts data using AES-CTR with proper state machine. /// /// This reader handles partial reads correctly by maintaining internal state /// and never losing any data that has been read from upstream. @@ -78,25 +108,24 @@ impl StreamState for CryptoReaderState { /// ┌──────────┐ read ┌──────────┐ /// │ Idle │ ------------> │ Yielding │ /// │ │ <------------ │ │ -/// └──────────┘ drained └──────────┘ +/// └──────────┘ drained └──────────┘ /// │ │ /// │ errors │ +/// ▼ ▼ /// ┌──────────────────────────────────────┐ /// │ Poisoned │ /// └──────────────────────────────────────┘ pub struct CryptoReader { - /// Upstream reader upstream: R, - /// AES-CTR decryptor decryptor: AesCtr, - /// Current state state: CryptoReaderState, - /// Internal read buffer for upstream reads + + /// Reserved for future coalescing optimizations. + #[allow(dead_code)] read_buf: BytesMut, } impl CryptoReader { - /// Create new crypto reader pub fn new(upstream: R, decryptor: AesCtr) -> Self { Self { upstream, @@ -105,45 +134,36 @@ impl CryptoReader { read_buf: BytesMut::with_capacity(DEFAULT_READ_CAPACITY), } } - - /// Get reference to upstream + pub fn get_ref(&self) -> &R { &self.upstream } - - /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut R { &mut self.upstream } - - /// Consume and return upstream + pub fn into_inner(self) -> R { self.upstream } - - /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } - - /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { self.state.state_name() } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = CryptoReaderState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoReaderState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } @@ -156,67 +176,61 @@ impl AsyncRead for CryptoReader { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - + loop { match &mut this.state { - // Poisoned state - return error CryptoReaderState::Poisoned { .. } => { let err = this.take_poison_error(); return Poll::Ready(Err(err)); } - - // Have buffered data to yield + CryptoReaderState::Yielding { buffer } => { if buf.remaining() == 0 { return Poll::Ready(Ok(())); } - - // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); let dst = buf.initialize_unfilled_to(to_copy); let copied = buffer.copy_to(dst); buf.advance(copied); - - // If buffer is drained, transition to Idle + if buffer.is_empty() { this.state = CryptoReaderState::Idle; } - + return Poll::Ready(Ok(())); } - - // Ready to read from upstream + CryptoReaderState::Idle => { - // If caller's buffer is empty, nothing to do if buf.remaining() == 0 { return Poll::Ready(Ok(())); } - - // Try to read directly into caller's buffer for zero-copy path - // We need to be careful: read into unfilled portion, then decrypt - let before_len = buf.filled().len(); - + + // Read directly into caller buffer, decrypt in-place for the bytes read. + let before = buf.filled().len(); + match Pin::new(&mut this.upstream).poll_read(cx, buf) { Poll::Pending => return Poll::Pending, - + Poll::Ready(Err(e)) => { this.poison(io::Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + Poll::Ready(Ok(())) => { - let after_len = buf.filled().len(); - let bytes_read = after_len - before_len; - + let after = buf.filled().len(); + let bytes_read = after - before; + if bytes_read == 0 { // EOF return Poll::Ready(Ok(())); } - - // Decrypt the newly read data in-place + let filled = buf.filled_mut(); - this.decryptor.apply(&mut filled[before_len..after_len]); - + this.decryptor.apply(&mut filled[before..after]); + + trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk"); + return Poll::Ready(Ok(())); } } @@ -227,115 +241,197 @@ impl AsyncRead for CryptoReader { } impl CryptoReader { - /// Read and decrypt exactly n bytes - /// - /// This is a convenience method that accumulates data until - /// exactly n bytes are available. + /// Read and decrypt exactly n bytes. pub async fn read_exact_decrypt(&mut self, n: usize) -> Result { use tokio::io::AsyncReadExt; - + if self.is_poisoned() { return Err(self.take_poison_error()); } - + let mut result = BytesMut::with_capacity(n); - - // First drain any buffered data from Yielding state + + // Drain Yielding buffer if present (rare, kept for completeness) if let CryptoReaderState::Yielding { buffer } = &mut self.state { let to_take = buffer.remaining().min(n); let mut temp = vec![0u8; to_take]; buffer.copy_to(&mut temp); result.extend_from_slice(&temp); - + if buffer.is_empty() { self.state = CryptoReaderState::Idle; } } - - // Read remaining from upstream + while result.len() < n { let mut temp = vec![0u8; n - result.len()]; let read = self.read(&mut temp).await?; - + if read == 0 { return Err(io::Error::new( ErrorKind::UnexpectedEof, - format!("expected {} bytes, got {}", n, result.len()) + format!("expected {} bytes, got {}", n, result.len()), )); } - + result.extend_from_slice(&temp[..read]); } - + Ok(result.freeze()) } - - /// Read into internal buffer and return decrypted bytes - /// - /// Useful when you need the data as Bytes rather than copying to a slice. + + /// Read up to max_size bytes, returning decrypted bytes as Bytes. pub async fn read_decrypt(&mut self, max_size: usize) -> Result { use tokio::io::AsyncReadExt; - + if self.is_poisoned() { return Err(self.take_poison_error()); } - - // First check if we have buffered data + if let CryptoReaderState::Yielding { buffer } = &mut self.state { let to_take = buffer.remaining().min(max_size); let mut temp = vec![0u8; to_take]; buffer.copy_to(&mut temp); - + if buffer.is_empty() { self.state = CryptoReaderState::Idle; } - + return Ok(Bytes::from(temp)); } - - // Read from upstream + let mut temp = vec![0u8; max_size]; let read = self.read(&mut temp).await?; - + if read == 0 { return Ok(Bytes::new()); } - + temp.truncate(read); Ok(Bytes::from(temp)) } } +// ============= Pending Ciphertext ============= + +/// Pending ciphertext buffer with explicit position and strict max size. +/// +/// - append plaintext then encrypt appended range in-place - one-touch copy, no extra Vec +/// - move ciphertext from scratch into pending without copying +/// - explicit compaction behavior for long-lived connections +#[derive(Debug)] +struct PendingCiphertext { + buf: BytesMut, + pos: usize, + max_len: usize, +} + +impl PendingCiphertext { + fn new(max_len: usize) -> Self { + Self { + buf: BytesMut::with_capacity(16 * 1024), + pos: 0, + max_len, + } + } + + fn pending_len(&self) -> usize { + self.buf.len().saturating_sub(self.pos) + } + + fn is_empty(&self) -> bool { + self.pending_len() == 0 + } + + fn pending_slice(&self) -> &[u8] { + &self.buf[self.pos..] + } + + fn remaining_capacity(&self) -> usize { + self.max_len.saturating_sub(self.buf.len()) + } + + fn advance(&mut self, n: usize) { + self.pos = (self.pos + n).min(self.buf.len()); + + if self.pos == self.buf.len() { + self.buf.clear(); + self.pos = 0; + return; + } + + // Compact when a large prefix was consumed. + if self.pos >= 32 * 1024 { + let _ = self.buf.split_to(self.pos); + self.pos = 0; + } + } + + /// Replace the entire pending ciphertext by moving `src` in (swap, no copy). + /// + /// Precondition: src.len() <= max_len. + fn replace_with(&mut self, mut src: BytesMut) { + debug_assert!(src.len() <= self.max_len); + + self.buf.clear(); + self.pos = 0; + + // Swap: keep allocations hot and avoid copying bytes. + std::mem::swap(&mut self.buf, &mut src); + } + + /// Append plaintext and encrypt appended range in-place. + /// + /// This is the high-throughput buffering path: + /// - copy plaintext into pending buffer + /// - encrypt only the newly appended bytes + /// + /// CTR state advances by exactly plaintext.len(). + fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> { + if plaintext.is_empty() { + return Ok(()); + } + + if plaintext.len() > self.remaining_capacity() { + return Err(io::Error::new( + ErrorKind::WouldBlock, + "pending ciphertext buffer is full", + )); + } + + let start = self.buf.len(); + self.buf.reserve(plaintext.len()); + self.buf.extend_from_slice(plaintext); + + encryptor.apply(&mut self.buf[start..]); + + Ok(()) + } +} + // ============= CryptoWriter State ============= -/// State machine states for CryptoWriter #[derive(Debug)] enum CryptoWriterState { - /// Ready to accept new data + /// No pending ciphertext buffered. Idle, - - /// Have pending encrypted data to flush - Flushing { - /// Buffer of encrypted data waiting to be written - pending: WriteBuffer, - }, - + + /// There is pending ciphertext to flush. + Flushing { pending: PendingCiphertext }, + /// Stream encountered an error and cannot be used - Poisoned { - /// The error that caused poisoning - error: Option, - }, + Poisoned { error: Option }, } impl StreamState for CryptoWriterState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -347,154 +443,178 @@ impl StreamState for CryptoWriterState { // ============= CryptoWriter ============= -/// Writer that encrypts data using AES-CTR with proper state machine +/// Writer that encrypts data using AES-CTR with correct async semantics. /// -/// This writer handles partial writes correctly by: -/// - Maintaining internal state for pending data -/// - Returning honest byte counts (only what's actually written or safely buffered) -/// - Implementing backpressure when internal buffer is full -/// -/// # State Machine -/// -/// ┌──────────┐ write ┌──────────┐ -/// │ Idle │ ----------> │ Flushing │ -/// │ │ <---------- │ │ -/// └──────────┘ flushed └──────────┘ -/// │ │ -/// │ errors │ -/// ┌───────────────────────────────────┐ -/// │ Poisoned │ -/// └───────────────────────────────────┘ -/// -/// # Backpressure -/// -/// When the internal pending buffer exceeds `MAX_PENDING_WRITE`, the writer -/// will return `Poll::Pending` until some data has been flushed to upstream. +/// - CTR state advances exactly by the number of bytes we report as written +/// - If upstream blocks, ciphertext is buffered/bounded +/// - Backpressure is applied when buffer is full pub struct CryptoWriter { - /// Upstream writer upstream: W, - /// AES-CTR encryptor encryptor: AesCtr, - /// Current state state: CryptoWriterState, + + /// Scratch ciphertext for fast "write-through" path. + /// + /// Flow: + /// - encrypt plaintext into scratch + /// - try upstream write + /// - if Pending/partial: move remainder into pending without re-encrypting + scratch: BytesMut, } impl CryptoWriter { - /// Create new crypto writer pub fn new(upstream: W, encryptor: AesCtr) -> Self { Self { upstream, encryptor, state: CryptoWriterState::Idle, + scratch: BytesMut::with_capacity(16 * 1024), } } - - /// Get reference to upstream + pub fn get_ref(&self) -> &W { &self.upstream } - - /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut W { &mut self.upstream } - - /// Consume and return upstream + pub fn into_inner(self) -> W { self.upstream } - - /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } - - /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { self.state.state_name() } - - /// Check if there's pending data to flush + pub fn has_pending(&self) -> bool { - matches!(&self.state, CryptoWriterState::Flushing { pending } if !pending.is_empty()) + matches!(self.state, CryptoWriterState::Flushing { .. }) } - - /// Get pending bytes count + pub fn pending_len(&self) -> usize { match &self.state { - CryptoWriterState::Flushing { pending } => pending.len(), + CryptoWriterState::Flushing { pending } => pending.pending_len(), _ => 0, } } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = CryptoWriterState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoWriterState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } + + /// Ensure we are in Flushing state and return mutable pending buffer. + fn ensure_pending<'a>(state: &'a mut CryptoWriterState) -> &'a mut PendingCiphertext { + if matches!(state, CryptoWriterState::Idle) { + *state = CryptoWriterState::Flushing { + pending: PendingCiphertext::new(MAX_PENDING_WRITE), + }; + } + + match state { + CryptoWriterState::Flushing { pending } => pending, + _ => unreachable!("ensure_pending guarantees Flushing state"), + } + } + + /// Select how many plaintext bytes can be accepted in buffering path + /// + /// Requirement: worst case - upstream pending, must buffer all ciphertext + /// for the accepted bytes + fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize { + if buf_len == 0 { + return 0; + } + + match state { + CryptoWriterState::Flushing { pending } => buf_len.min(pending.remaining_capacity()), + CryptoWriterState::Idle => buf_len.min(MAX_PENDING_WRITE), + CryptoWriterState::Poisoned { .. } => 0, + } + } + + /// Encrypt plaintext into scratch (CTR advances by plaintext.len()). + fn encrypt_into_scratch(encryptor: &mut AesCtr, scratch: &mut BytesMut, plaintext: &[u8]) { + scratch.clear(); + scratch.reserve(plaintext.len()); + scratch.extend_from_slice(plaintext); + encryptor.apply(&mut scratch[..]); + } } impl CryptoWriter { - /// Try to flush pending data to upstream + /// Flush as much pending ciphertext as possible /// - /// Returns: - /// - `Poll::Ready(Ok(true))` if all pending data was flushed - /// - `Poll::Ready(Ok(false))` if some data remains - /// - `Poll::Pending` if upstream would block - /// - `Poll::Ready(Err(_))` on error - fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { + /// Returns + /// - Ready(Ok(())) if all pending is flushed or was none + /// - Pending if upstream would block + /// - Ready(Err(_)) on error + fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - CryptoWriterState::Idle => { - return Poll::Ready(Ok(true)); - } - CryptoWriterState::Poisoned { .. } => { let err = self.take_poison_error(); return Poll::Ready(Err(err)); } - + + CryptoWriterState::Idle => return Poll::Ready(Ok(())), + CryptoWriterState::Flushing { pending } => { if pending.is_empty() { self.state = CryptoWriterState::Idle; - return Poll::Ready(Ok(true)); + return Poll::Ready(Ok(())); } - - let data = pending.pending(); + + let data = pending.pending_slice(); + match Pin::new(&mut self.upstream).poll_write(cx, data) { - Poll::Pending => return Poll::Pending, - + Poll::Pending => { + trace!( + pending_len = pending.pending_len(), + pending_cap = pending.remaining_capacity(), + "CryptoWriter: upstream Pending while flushing pending ciphertext" + ); + return Poll::Pending; + } + Poll::Ready(Err(e)) => { self.poison(io::Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + Poll::Ready(Ok(0)) => { let err = io::Error::new( ErrorKind::WriteZero, - "upstream returned 0 bytes written" + "upstream returned 0 bytes written", ); - self.poison(err.into()); - return Poll::Ready(Err(io::Error::new( - ErrorKind::WriteZero, - "upstream returned 0 bytes written" - ))); + self.poison(io::Error::new(err.kind(), err.to_string())); + return Poll::Ready(Err(err)); } - + Poll::Ready(Ok(n)) => { pending.advance(n); - // Continue loop to check if fully flushed + + trace!( + flushed = n, + pending_left = pending.pending_len(), + "CryptoWriter: flushed pending ciphertext" + ); + + // continue loop to flush more + continue; } } } @@ -510,153 +630,176 @@ impl AsyncWrite for CryptoWriter { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - - // Check for poisoned state - if let CryptoWriterState::Poisoned { .. } = &this.state { + + // Poisoned? + if matches!(this.state, CryptoWriterState::Poisoned { .. }) { let err = this.take_poison_error(); return Poll::Ready(Err(err)); } - - // Empty write is always successful + + // Empty write is always OK if buf.is_empty() { return Poll::Ready(Ok(0)); } - - // First, try to flush any pending data - match this.poll_flush_pending(cx) { - Poll::Pending => { - // Check backpressure - if this.pending_len() >= MAX_PENDING_WRITE { - // Too much pending, must wait - return Poll::Pending; + + // 1) If we have pending ciphertext, prioritize flushing it + // If upstream pending + // -> still accept some plaintext ONLY if we can buffer + // all ciphertext for the accepted portion - bounded + if matches!(this.state, CryptoWriterState::Flushing { .. }) { + match this.poll_flush_pending(cx) { + Poll::Ready(Ok(())) => { + // pending drained -> proceed + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Upstream blocked. Apply ideal backpressure + // - accept up to remaining pending capacity + // - if no capacity -> pending + let to_accept = + Self::select_to_accept_for_buffering(&this.state, buf.len()); + + if to_accept == 0 { + trace!( + buf_len = buf.len(), + pending_len = this.pending_len(), + "CryptoWriter backpressure: pending full and upstream Pending -> Pending" + ); + return Poll::Pending; + } + + let plaintext = &buf[..to_accept]; + + // Disjoint borrows: borrow encryptor and state separately via a match + let encryptor = &mut this.encryptor; + let pending = Self::ensure_pending(&mut this.state); + + // Should not WouldBlock because to_accept <= remaining_capacity + if let Err(e) = pending.push_encrypted(encryptor, plaintext) { + if e.kind() == ErrorKind::WouldBlock { + return Poll::Pending; + } + return Poll::Ready(Err(e)); + } + + trace!( + accepted = to_accept, + pending_len = pending.pending_len(), + pending_cap = pending.remaining_capacity(), + "CryptoWriter: upstream Pending, buffered ciphertext (accepted plaintext)" + ); + + return Poll::Ready(Ok(to_accept)); } - // Can buffer more, continue below - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(_)) => { - // Flushed (possibly partially), continue } } - - // Encrypt the data - let mut encrypted = buf.to_vec(); - this.encryptor.apply(&mut encrypted); - - // Try to write directly to upstream first - match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) { - Poll::Ready(Ok(n)) if n == encrypted.len() => { - // All data written directly - Poll::Ready(Ok(buf.len())) + + // 2) Fast path: pending empty -> write-through + debug_assert!(matches!(this.state, CryptoWriterState::Idle)); + + // Worst-case buffering requirement + // - If upstream becomes pending -> buffer full ciphertext for accepted bytes + // -> accept at most MAX_PENDING_WRITE per poll_write call + let to_accept = buf.len().min(MAX_PENDING_WRITE); + let plaintext = &buf[..to_accept]; + + Self::encrypt_into_scratch(&mut this.encryptor, &mut this.scratch, plaintext); + + match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) { + Poll::Pending => { + // Upstream blocked: buffer FULL ciphertext for accepted bytes. + // Move scratch into pending without copying. + let ciphertext = std::mem::take(&mut this.scratch); + + let pending = Self::ensure_pending(&mut this.state); + pending.replace_with(ciphertext); + + trace!( + accepted = to_accept, + pending_len = pending.pending_len(), + "CryptoWriter: write-through got Pending, buffered full ciphertext" + ); + + Poll::Ready(Ok(to_accept)) } - - Poll::Ready(Ok(n)) => { - // Partial write - buffer the rest - let remaining = &encrypted[n..]; - - // Ensure we're in Flushing state - let pending = match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - CryptoWriterState::Idle => { - this.state = CryptoWriterState::Flushing { - pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), - }; - match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - _ => unreachable!(), - } - } - CryptoWriterState::Poisoned { .. } => unreachable!(), - }; - - // Try to buffer remaining - if pending.remaining_capacity() >= remaining.len() { - pending.extend(remaining).expect("capacity checked"); - Poll::Ready(Ok(buf.len())) - } else { - // Not enough buffer space - report what we could write - // The caller will need to retry with the rest - let bytes_accepted = n + pending.remaining_capacity(); - if bytes_accepted > n { - let can_buffer = &encrypted[n..bytes_accepted]; - pending.extend(can_buffer).expect("capacity checked"); - } - Poll::Ready(Ok(bytes_accepted.min(buf.len()))) - } - } - + Poll::Ready(Err(e)) => { this.poison(io::Error::new(e.kind(), e.to_string())); Poll::Ready(Err(e)) } - - Poll::Pending => { - // Upstream would block - buffer the encrypted data - let pending = match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - CryptoWriterState::Idle => { - this.state = CryptoWriterState::Flushing { - pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), - }; - match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - _ => unreachable!(), - } - } - CryptoWriterState::Poisoned { .. } => unreachable!(), - }; - - // Check if we can buffer all - if pending.remaining_capacity() >= encrypted.len() { - pending.extend(&encrypted).expect("capacity checked"); - // Wake up to try flushing later - cx.waker().wake_by_ref(); - Poll::Ready(Ok(buf.len())) - } else if pending.remaining_capacity() > 0 { - // Partial buffer - let can_buffer = pending.remaining_capacity(); - pending.extend(&encrypted[..can_buffer]).expect("capacity checked"); - cx.waker().wake_by_ref(); - Poll::Ready(Ok(can_buffer)) - } else { - // No buffer space - backpressure - Poll::Pending + + Poll::Ready(Ok(0)) => { + let err = io::Error::new(ErrorKind::WriteZero, "upstream returned 0 bytes written"); + this.poison(io::Error::new(err.kind(), err.to_string())); + Poll::Ready(Err(err)) + } + + Poll::Ready(Ok(n)) => { + if n == this.scratch.len() { + trace!( + accepted = to_accept, + ciphertext_len = this.scratch.len(), + "CryptoWriter: write-through wrote full ciphertext directly" + ); + this.scratch.clear(); + return Poll::Ready(Ok(to_accept)); } + + // Partial upstream write of ciphertext: + // We accepted `to_accept` plaintext bytes, CTR already advanced for to_accept + // Must buffer the remainder ciphertext + warn!( + accepted = to_accept, + ciphertext_len = this.scratch.len(), + written_ciphertext = n, + "CryptoWriter: partial upstream write, buffering remainder" + ); + + // Split off remainder without copying + let remainder = this.scratch.split_off(n); + this.scratch.clear(); + + let pending = Self::ensure_pending(&mut this.state); + pending.replace_with(remainder); + + Poll::Ready(Ok(to_accept)) } } } - + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // First flush our pending buffer - match this.poll_flush_pending(cx)? { - Poll::Pending => return Poll::Pending, - Poll::Ready(false) => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - Poll::Ready(true) => {} + + if matches!(this.state, CryptoWriterState::Poisoned { .. }) { + let err = this.take_poison_error(); + return Poll::Ready(Err(err)); } - - // Then flush upstream + + match this.poll_flush_pending(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + } + Pin::new(&mut this.upstream).poll_flush(cx) } - + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Try to flush pending data first (best effort) + + // Best-effort flush pending ciphertext before shutdown + // If upstream blocks, proceed to shutdown anyway match this.poll_flush_pending(cx) { Poll::Pending => { - // Continue with shutdown anyway after registering waker + debug!( + pending_len = this.pending_len(), + "CryptoWriter: shutdown with pending ciphertext (upstream Pending)" + ); } - Poll::Ready(Err(_)) => { - // Ignore flush errors during shutdown - } - Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(_)) => {} + Poll::Ready(Ok(())) => {} } - - // Shutdown upstream + Pin::new(&mut this.upstream).poll_shutdown(cx) } } @@ -666,28 +809,24 @@ impl AsyncWrite for CryptoWriter { /// Passthrough stream for fast mode - no encryption/decryption /// /// Used when keys are set up so that client and Telegram use the same -/// encryption, allowing data to pass through without re-encryption. +/// encryption, allowing data to pass through without re-encryption pub struct PassthroughStream { inner: S, } impl PassthroughStream { - /// Create new passthrough stream pub fn new(inner: S) -> Self { Self { inner } } - - /// Get reference to inner stream + pub fn get_ref(&self) -> &S { &self.inner } - - /// Get mutable reference to inner stream + pub fn get_mut(&mut self) -> &mut S { &mut self.inner } - - /// Consume and return inner stream + pub fn into_inner(self) -> S { self.inner } @@ -711,443 +850,12 @@ impl AsyncWrite for PassthroughStream { ) -> Poll> { Pin::new(&mut self.inner).poll_write(cx, buf) } - + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx) } - + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } -} - -// ============= Tests ============= - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::VecDeque; - use std::pin::Pin; - use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable}; - use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; - - // ============= Test Helpers ============= - - fn noop_waker() -> Waker { - const VTABLE: RawWakerVTable = RawWakerVTable::new( - |_| RawWaker::new(std::ptr::null(), &VTABLE), - |_| {}, - |_| {}, - |_| {}, - ); - unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } - } - - /// Mock writer that simulates partial writes - struct PartialWriter { - /// Max bytes to accept per write - chunk_size: usize, - /// Collected data - data: Vec, - /// Number of writes performed - write_count: usize, - /// If true, return Pending on first write attempt - first_pending: bool, - /// Track if first call happened - first_call: bool, - } - - impl PartialWriter { - fn new(chunk_size: usize) -> Self { - Self { - chunk_size, - data: Vec::new(), - write_count: 0, - first_pending: false, - first_call: true, - } - } - - fn with_first_pending(mut self) -> Self { - self.first_pending = true; - self - } - } - - impl AsyncWrite for PartialWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.first_pending && self.first_call { - self.first_call = false; - cx.waker().wake_by_ref(); - return Poll::Pending; - } - - self.write_count += 1; - let to_write = buf.len().min(self.chunk_size); - self.data.extend_from_slice(&buf[..to_write]); - Poll::Ready(Ok(to_write)) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - } - - /// Mock reader that returns data in chunks - struct ChunkedReader { - data: VecDeque, - chunk_size: usize, - } - - impl ChunkedReader { - fn new(data: &[u8], chunk_size: usize) -> Self { - Self { - data: data.iter().copied().collect(), - chunk_size, - } - } - } - - impl AsyncRead for ChunkedReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.data.is_empty() { - return Poll::Ready(Ok(())); - } - - let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining()); - for _ in 0..to_read { - if let Some(byte) = self.data.pop_front() { - buf.put_slice(&[byte]); - } - } - - Poll::Ready(Ok(())) - } - } - - // ============= CryptoReader Tests ============= - - #[tokio::test] - async fn test_crypto_reader_basic() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Encrypt some data - let original = b"Hello, encrypted world!"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - // Create reader - let reader = ChunkedReader::new(&encrypted, 100); - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - // Read and decrypt - let mut buf = vec![0u8; original.len()]; - crypto_reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, original); - } - - #[tokio::test] - async fn test_crypto_reader_chunked() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let original = b"This is a longer message that will be read in chunks"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - // Read in very small chunks - let reader = ChunkedReader::new(&encrypted, 5); - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - let mut result = Vec::new(); - let mut buf = [0u8; 7]; // Read in chunks different from write chunks - - loop { - let n = crypto_reader.read(&mut buf).await.unwrap(); - if n == 0 { - break; - } - result.extend_from_slice(&buf[..n]); - } - - assert_eq!(&result, original); - } - - #[tokio::test] - async fn test_crypto_reader_read_exact_decrypt() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let original = b"Exact read test data!"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - let reader = ChunkedReader::new(&encrypted, 3); // Small chunks - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - let result = crypto_reader.read_exact_decrypt(original.len()).await.unwrap(); - assert_eq!(&result[..], original); - } - - // ============= CryptoWriter Tests ============= - - #[test] - fn test_crypto_writer_basic_sync() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let mock_writer = PartialWriter::new(100); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = b"Hello, world!"; - - // Write - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(13)))); - - // Verify encryption happened - let encrypted = &crypto_writer.upstream.data; - assert_eq!(encrypted.len(), original.len()); - assert_ne!(encrypted.as_slice(), original); // Should be encrypted - - // Decrypt and verify - let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = encrypted.clone(); - decryptor.apply(&mut decrypted); - assert_eq!(&decrypted, original); - } - - #[test] - fn test_crypto_writer_partial_write() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Writer that only accepts 5 bytes at a time - let mock_writer = PartialWriter::new(5); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = b"This is a longer message!"; // 25 bytes - - // First write - should accept all 25 bytes (5 written, 20 buffered) - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(25)))); - - // Should have pending data - assert!(crypto_writer.has_pending()); - - // Flush to drain pending - loop { - match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { - Poll::Ready(Ok(())) => break, - Poll::Ready(Err(e)) => panic!("Flush error: {}", e), - Poll::Pending => continue, - } - } - - // All data should be written now - assert!(!crypto_writer.has_pending()); - assert_eq!(crypto_writer.upstream.data.len(), 25); - - // Verify decryption - let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = crypto_writer.upstream.data.clone(); - decryptor.apply(&mut decrypted); - assert_eq!(&decrypted, original); - } - - #[test] - fn test_crypto_writer_pending_on_first_write() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Writer that returns Pending on first call - let mock_writer = PartialWriter::new(100).with_first_pending(); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = b"Test data"; - - // First write should buffer and return Ready (not Pending) - // because we have buffer space - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(9)))); - - // Data should be buffered - assert!(crypto_writer.has_pending()); - - // Second poll_flush should succeed - loop { - match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { - Poll::Ready(Ok(())) => break, - Poll::Ready(Err(e)) => panic!("Flush error: {}", e), - Poll::Pending => continue, - } - } - } - - #[tokio::test] - async fn test_crypto_stream_roundtrip() { - let key = [0u8; 32]; - let iv = 12345u128; - - let (client, server) = duplex(4096); - - let encryptor = AesCtr::new(&key, iv); - let decryptor = AesCtr::new(&key, iv); - - let mut writer = CryptoWriter::new(client, encryptor); - let mut reader = CryptoReader::new(server, decryptor); - - // Write - let original = b"Hello, encrypted world!"; - writer.write_all(original).await.unwrap(); - writer.flush().await.unwrap(); - - // Read - let mut buf = vec![0u8; original.len()]; - reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, original); - } - - #[tokio::test] - async fn test_crypto_stream_large_data() { - let key = [0x55u8; 32]; - let iv = 777u128; - - let (client, server) = duplex(1024); - - let encryptor = AesCtr::new(&key, iv); - let decryptor = AesCtr::new(&key, iv); - - let mut writer = CryptoWriter::new(client, encryptor); - let mut reader = CryptoReader::new(server, decryptor); - - // Large data - let original: Vec = (0..10000).map(|i| (i % 256) as u8).collect(); - - // Write in background - let write_data = original.clone(); - let write_handle = tokio::spawn(async move { - writer.write_all(&write_data).await.unwrap(); - writer.flush().await.unwrap(); - writer.shutdown().await.unwrap(); - }); - - // Read - let mut received = Vec::new(); - let mut buf = vec![0u8; 1024]; - loop { - match reader.read(&mut buf).await { - Ok(0) => break, - Ok(n) => received.extend_from_slice(&buf[..n]), - Err(e) => panic!("Read error: {}", e), - } - } - - write_handle.await.unwrap(); - - assert_eq!(received, original); - } - - #[tokio::test] - async fn test_crypto_writer_backpressure() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Very small buffer duplex - let (client, _server) = duplex(64); - - let encryptor = AesCtr::new(&key, iv); - let mut writer = CryptoWriter::new(client, encryptor); - - // Try to write a lot of data - let large_data = vec![0u8; MAX_PENDING_WRITE + 1000]; - - // This should eventually block due to backpressure - // (duplex buffer full + our pending buffer full) - let write_result = tokio::time::timeout( - std::time::Duration::from_millis(100), - writer.write_all(&large_data) - ).await; - - // Should timeout because we can't write all data - assert!(write_result.is_err()); - } - - // ============= State Tests ============= - - #[test] - fn test_reader_state_transitions() { - let key = [0u8; 32]; - let iv = 0u128; - - let reader = ChunkedReader::new(&[], 10); - let decryptor = AesCtr::new(&key, iv); - let reader = CryptoReader::new(reader, decryptor); - - assert_eq!(reader.state_name(), "Idle"); - assert!(!reader.is_poisoned()); - } - - #[test] - fn test_writer_state_transitions() { - let key = [0u8; 32]; - let iv = 0u128; - - let writer = PartialWriter::new(10); - let encryptor = AesCtr::new(&key, iv); - let writer = CryptoWriter::new(writer, encryptor); - - assert_eq!(writer.state_name(), "Idle"); - assert!(!writer.is_poisoned()); - assert!(!writer.has_pending()); - } - - // ============= Passthrough Tests ============= - - #[tokio::test] - async fn test_passthrough_stream() { - let (client, server) = duplex(4096); - - let mut writer = PassthroughStream::new(client); - let mut reader = PassthroughStream::new(server); - - let data = b"No encryption here!"; - writer.write_all(data).await.unwrap(); - writer.flush().await.unwrap(); - - let mut buf = vec![0u8; data.len()]; - reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, data); - } } \ No newline at end of file