From ffe5a6cfb761d0649a689d3ccdd7971a6be10003 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Fri, 2 Jan 2026 01:17:56 +0300 Subject: [PATCH] Fake TLS Fixes for Async IO added more comments and schemas --- src/stream/tls_stream.rs | 1354 +++++++++++++++++++++++++++++++++----- 1 file changed, 1204 insertions(+), 150 deletions(-) diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index fbe2f5e..287fdb8 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -1,26 +1,207 @@ //! Fake TLS 1.3 stream wrappers +//! +//! This module provides stateful async stream wrappers that handle +//! TLS record framing with proper partial read/write handling. +//! +//! These are "fake" TLS streams - they wrap data in valid TLS 1.3 +//! Application Data records but don't perform actual TLS encryption. +//! The actual encryption is handled by the crypto layer underneath. +//! +//! Key design principles: +//! - Explicit state machines for all async operations +//! - Never lose data on partial reads +//! - Atomic TLS record formation for writes +//! - Proper handling of all TLS record types -use bytes::{Bytes, BytesMut}; -use std::io::{Error, ErrorKind, Result}; +use bytes::{Bytes, BytesMut, BufMut}; +use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; + use crate::protocol::constants::{ TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, - MAX_TLS_CHUNK_SIZE, + TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE, }; -use parking_lot::Mutex; +use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; -/// Reader that unwraps TLS 1.3 records -pub struct FakeTlsReader { - upstream: R, - buffer: BytesMut, - pending_read: Option, +// ============= Constants ============= + +/// TLS record header size +const TLS_HEADER_SIZE: usize = 5; + +/// Maximum TLS record payload size (16KB as per TLS spec) +const MAX_TLS_PAYLOAD: usize = 16384; + +/// Maximum pending write buffer +const MAX_PENDING_WRITE: usize = 64 * 1024; + +// ============= TLS Record Types ============= + +/// Parsed TLS record header +#[derive(Debug, Clone, Copy)] +struct TlsRecordHeader { + /// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.) + record_type: u8, + /// TLS version bytes + version: [u8; 2], + /// Payload length + length: u16, } -struct PendingTlsRead { - record_type: u8, - remaining: usize, +impl TlsRecordHeader { + /// Parse header from 5 bytes + fn parse(header: &[u8; 5]) -> Option { + let record_type = header[0]; + let version = [header[1], header[2]]; + let length = u16::from_be_bytes([header[3], header[4]]); + + Some(Self { + record_type, + version, + length, + }) + } + + /// Validate the header + fn validate(&self) -> Result<()> { + // Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others) + if self.version != [0x03, 0x01] && self.version != TLS_VERSION { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Invalid TLS version: {:02x?}", self.version), + )); + } + + // Check length + if self.length as usize > MAX_TLS_RECORD_SIZE { + return Err(Error::new( + ErrorKind::InvalidData, + format!("TLS record too large: {} bytes", self.length), + )); + } + + Ok(()) + } + + /// Check if this is an application data record + fn is_application_data(&self) -> bool { + self.record_type == TLS_RECORD_APPLICATION + } + + /// Check if this is a change cipher spec record (should be skipped) + fn is_change_cipher_spec(&self) -> bool { + self.record_type == TLS_RECORD_CHANGE_CIPHER + } + + /// Build header bytes + fn to_bytes(&self) -> [u8; 5] { + [ + self.record_type, + self.version[0], + self.version[1], + (self.length >> 8) as u8, + self.length as u8, + ] + } +} + +// ============= FakeTlsReader State ============= + +/// State machine states for FakeTlsReader +#[derive(Debug)] +enum TlsReaderState { + /// Ready to read a new TLS record + Idle, + + /// Reading the 5-byte TLS record header + ReadingHeader { + /// Header buffer (5 bytes) + header: HeaderBuffer, + }, + + /// Reading the TLS record body + ReadingBody { + /// Parsed record type + record_type: u8, + /// Total body length + length: usize, + /// Buffer for body data + buffer: BytesMut, + }, + + /// Have decrypted data ready to yield to caller + Yielding { + /// Buffer containing data to yield + buffer: YieldBuffer, + }, + + /// Stream encountered an error and cannot be used + Poisoned { + /// The error that caused poisoning + error: Option, + }, +} + +impl StreamState for TlsReaderState { + fn is_terminal(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn is_poisoned(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn state_name(&self) -> &'static str { + match self { + Self::Idle => "Idle", + Self::ReadingHeader { .. } => "ReadingHeader", + Self::ReadingBody { .. } => "ReadingBody", + Self::Yielding { .. } => "Yielding", + Self::Poisoned { .. } => "Poisoned", + } + } +} + +// ============= FakeTlsReader ============= + +/// Reader that unwraps TLS 1.3 records with proper state machine +/// +/// This reader handles partial reads correctly by maintaining internal state +/// and never losing any data that has been read from upstream. +/// +/// # State Machine +/// +/// ┌──────────┐ ┌───────────────┐ +/// │ Idle │ -----------------> │ ReadingHeader │ +/// └──────────┘ └───────┬───────┘ +/// ▲ │ +/// │ header complete +/// │ │ +/// │ │ +/// │ ┌───────────────┐ +/// │ skip record │ ReadingBody │ +/// │ <-------- (CCS) -------- │ │ +/// │ └───────┬───────┘ +/// │ │ +/// │ body complete +/// │ drained │ +/// │ <-----------------┐ │ +/// │ │ ┌───────────────┐ +/// │ └----- │ Yielding │ +/// │ └───────────────┘ +/// │ +/// │ errors /w any state +/// │ +/// ┌───────────────────────────────────────────────┐ +/// │ Poisoned │ +/// └───────────────────────────────────────────────┘ +/// +pub struct FakeTlsReader { + /// Upstream reader + upstream: R, + /// Current state + state: TlsReaderState, } impl FakeTlsReader { @@ -28,8 +209,7 @@ impl FakeTlsReader { pub fn new(upstream: R) -> Self { Self { upstream, - buffer: BytesMut::with_capacity(16384), - pending_read: None, + state: TlsReaderState::Idle, } } @@ -47,138 +227,404 @@ impl FakeTlsReader { pub fn into_inner(self) -> R { self.upstream } -} - -impl FakeTlsReader { - /// Read exactly n bytes through TLS layer - pub async fn read_exact(&mut self, n: usize) -> Result { - while self.buffer.len() < n { - let data = self.read_tls_record().await?; - if data.is_empty() { - return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed")); - } - self.buffer.extend_from_slice(&data); - } - - Ok(self.buffer.split_to(n).freeze()) + + /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { + self.state.is_poisoned() } - /// Read a single TLS record - async fn read_tls_record(&mut self) -> Result> { - loop { - // Read TLS record header (5 bytes) - let mut header = [0u8; 5]; - self.upstream.read_exact(&mut header).await?; - - let record_type = header[0]; - let version = [header[1], header[2]]; - let length = u16::from_be_bytes([header[3], header[4]]) as usize; - - // Validate version - if version != TLS_VERSION { - return Err(Error::new( - ErrorKind::InvalidData, - format!("Invalid TLS version: {:02x?}", version), - )); - } - - // Read record body - let mut data = vec![0u8; length]; - self.upstream.read_exact(&mut data).await?; - - match record_type { - TLS_RECORD_CHANGE_CIPHER => continue, // Skip - TLS_RECORD_APPLICATION => return Ok(data), - _ => { - return Err(Error::new( - ErrorKind::InvalidData, - format!("Unexpected TLS record type: 0x{:02x}", record_type), - )); - } + /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { + self.state.state_name() + } + + /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { + self.state = TlsReaderState::Poisoned { error: Some(error) }; + } + + /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { + match &mut self.state { + TlsReaderState::Poisoned { error } => { + error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }) } + _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } } +/// Result of polling for header completion +enum HeaderPollResult { + /// Need more data + Pending, + /// EOF at record boundary (clean close) + Eof, + /// Header complete, parsed successfully + Complete(TlsRecordHeader), + /// Error occurred + Error(io::Error), +} + +/// Result of polling for body completion +enum BodyPollResult { + /// Need more data + Pending, + /// Body complete + Complete(Bytes), + /// Error occurred + Error(io::Error), +} + impl AsyncRead for FakeTlsReader { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - // Drain buffer first - if !self.buffer.is_empty() { - let to_copy = self.buffer.len().min(buf.remaining()); - buf.put_slice(&self.buffer.split_to(to_copy)); - return Poll::Ready(Ok(())); - } + let this = self.get_mut(); - // We need to read a TLS record, but poll_read doesn't support async/await - // So we'll do a simplified version that reads header synchronously - - // Read header - let mut header = [0u8; 5]; - let mut header_buf = ReadBuf::new(&mut header); - - match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) { - Poll::Ready(Ok(())) => { - if header_buf.filled().len() < 5 { - // Need more data - store what we have and return pending - // For simplicity, we'll just return empty + loop { + // Take ownership of state to avoid borrow conflicts + let state = std::mem::replace(&mut this.state, TlsReaderState::Idle); + + match state { + // Poisoned state - return error + TlsReaderState::Poisoned { error } => { + this.state = TlsReaderState::Poisoned { error: None }; + let err = error.unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }); + return Poll::Ready(Err(err)); + } + + // Have buffered data to yield + TlsReaderState::Yielding { mut buffer } => { + if buf.remaining() == 0 { + this.state = TlsReaderState::Yielding { buffer }; + return Poll::Ready(Ok(())); + } + + // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); + let dst = buf.initialize_unfilled_to(to_copy); + let copied = buffer.copy_to(dst); + buf.advance(copied); + + // If buffer is drained, transition to Idle + if buffer.is_empty() { + this.state = TlsReaderState::Idle; + } else { + this.state = TlsReaderState::Yielding { buffer }; + } + return Poll::Ready(Ok(())); } - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending => return Poll::Pending, - } - - let record_type = header[0]; - let length = u16::from_be_bytes([header[3], header[4]]) as usize; - - if record_type == TLS_RECORD_CHANGE_CIPHER { - // Skip this record, try again - cx.waker().wake_by_ref(); - return Poll::Pending; - } - - if record_type != TLS_RECORD_APPLICATION { - return Poll::Ready(Err(Error::new( - ErrorKind::InvalidData, - "Invalid TLS record type", - ))); - } - - // Read body - let mut body = vec![0u8; length]; - let mut body_buf = ReadBuf::new(&mut body); - - match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) { - Poll::Ready(Ok(())) => { - let filled = body_buf.filled(); - let to_copy = filled.len().min(buf.remaining()); - buf.put_slice(&filled[..to_copy]); - if filled.len() > to_copy { - self.buffer.extend_from_slice(&filled[to_copy..]); + // Ready to read a new TLS record + TlsReaderState::Idle => { + if buf.remaining() == 0 { + this.state = TlsReaderState::Idle; + return Poll::Ready(Ok(())); + } + + // Start reading header + this.state = TlsReaderState::ReadingHeader { + header: HeaderBuffer::new(), + }; + // Continue to ReadingHeader } - Poll::Ready(Ok(())) + // Reading TLS record header + TlsReaderState::ReadingHeader { mut header } => { + // Poll to fill header + let result = poll_read_header(&mut this.upstream, cx, &mut header); + + match result { + HeaderPollResult::Pending => { + this.state = TlsReaderState::ReadingHeader { header }; + return Poll::Pending; + } + HeaderPollResult::Eof => { + this.state = TlsReaderState::Idle; + return Poll::Ready(Ok(())); + } + HeaderPollResult::Error(e) => { + this.poison(Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + HeaderPollResult::Complete(parsed) => { + // Validate header + if let Err(e) = parsed.validate() { + this.poison(Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + + let length = parsed.length as usize; + + // Transition to reading body + this.state = TlsReaderState::ReadingBody { + record_type: parsed.record_type, + length, + buffer: BytesMut::with_capacity(length), + }; + } + } + } + + // Reading TLS record body + TlsReaderState::ReadingBody { record_type, length, mut buffer } => { + let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); + + match result { + BodyPollResult::Pending => { + this.state = TlsReaderState::ReadingBody { record_type, length, buffer }; + return Poll::Pending; + } + BodyPollResult::Error(e) => { + this.poison(Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + BodyPollResult::Complete(data) => { + // Handle different record types + match record_type { + TLS_RECORD_CHANGE_CIPHER => { + // Skip Change Cipher Spec, read next record + this.state = TlsReaderState::Idle; + continue; + } + TLS_RECORD_APPLICATION => { + // Application data - yield to caller + if data.is_empty() { + this.state = TlsReaderState::Idle; + continue; + } + + this.state = TlsReaderState::Yielding { + buffer: YieldBuffer::new(data), + }; + // Continue to yield + } + TLS_RECORD_ALERT => { + // TLS Alert - treat as EOF + this.state = TlsReaderState::Idle; + return Poll::Ready(Ok(())); + } + TLS_RECORD_HANDSHAKE => { + let err = Error::new( + ErrorKind::InvalidData, + "unexpected TLS handshake record" + ); + this.poison(Error::new(err.kind(), err.to_string())); + return Poll::Ready(Err(err)); + } + _ => { + let err = Error::new( + ErrorKind::InvalidData, + format!("unknown TLS record type: 0x{:02x}", record_type) + ); + this.poison(Error::new(err.kind(), err.to_string())); + return Poll::Ready(Err(err)); + } + } + } + } + } } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, } } } -/// Writer that wraps data in TLS 1.3 records +/// Poll to read and fill header buffer (standalone function to avoid borrow issues) +fn poll_read_header( + upstream: &mut R, + cx: &mut Context<'_>, + header: &mut HeaderBuffer, +) -> HeaderPollResult { + while !header.is_complete() { + let unfilled = header.unfilled_mut(); + let mut read_buf = ReadBuf::new(unfilled); + + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { + Poll::Pending => return HeaderPollResult::Pending, + Poll::Ready(Err(e)) => return HeaderPollResult::Error(e), + Poll::Ready(Ok(())) => { + let n = read_buf.filled().len(); + if n == 0 { + // EOF + if header.as_slice().is_empty() { + return HeaderPollResult::Eof; + } else { + return HeaderPollResult::Error(Error::new( + ErrorKind::UnexpectedEof, + format!("unexpected EOF in TLS header (got {} of 5 bytes)", + header.as_slice().len()) + )); + } + } + header.advance(n); + } + } + } + + // Parse header + let header_bytes = *header.as_array(); + match TlsRecordHeader::parse(&header_bytes) { + Some(h) => HeaderPollResult::Complete(h), + None => HeaderPollResult::Error(Error::new( + ErrorKind::InvalidData, + "failed to parse TLS header" + )), + } +} + +/// Poll to read record body (standalone function to avoid borrow issues) +fn poll_read_body( + upstream: &mut R, + cx: &mut Context<'_>, + buffer: &mut BytesMut, + target_len: usize, +) -> BodyPollResult { + while buffer.len() < target_len { + let remaining = target_len - buffer.len(); + + // Read into a temporary buffer + let mut temp = vec![0u8; remaining.min(8192)]; + let mut read_buf = ReadBuf::new(&mut temp); + + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { + Poll::Pending => return BodyPollResult::Pending, + Poll::Ready(Err(e)) => return BodyPollResult::Error(e), + Poll::Ready(Ok(())) => { + let n = read_buf.filled().len(); + if n == 0 { + return BodyPollResult::Error(Error::new( + ErrorKind::UnexpectedEof, + format!("unexpected EOF in TLS body (got {} of {} bytes)", + buffer.len(), target_len) + )); + } + buffer.extend_from_slice(&temp[..n]); + } + } + } + + BodyPollResult::Complete(buffer.split().freeze()) +} + +impl FakeTlsReader { + /// Read exactly n bytes through TLS layer + /// + /// This is a convenience method that accumulates data across + /// multiple TLS records until exactly n bytes are available. + pub async fn read_exact(&mut self, n: usize) -> Result { + if self.is_poisoned() { + return Err(self.take_poison_error()); + } + + let mut result = BytesMut::with_capacity(n); + + while result.len() < n { + let mut buf = vec![0u8; n - result.len()]; + let read = AsyncReadExt::read(self, &mut buf).await?; + + if read == 0 { + return Err(Error::new( + ErrorKind::UnexpectedEof, + format!("expected {} bytes, got {}", n, result.len()) + )); + } + + result.extend_from_slice(&buf[..read]); + } + + Ok(result.freeze()) + } +} + +// ============= FakeTlsWriter State ============= + +/// State machine states for FakeTlsWriter +#[derive(Debug)] +enum TlsWriterState { + /// Ready to accept new data + Idle, + + /// Writing a complete TLS record + WritingRecord { + /// Complete record (header + body) to write + record: WriteBuffer, + /// Original payload size (for return value calculation) + payload_size: usize, + }, + + /// Stream encountered an error and cannot be used + Poisoned { + /// The error that caused poisoning + error: Option, + }, +} + +impl StreamState for TlsWriterState { + fn is_terminal(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn is_poisoned(&self) -> bool { + matches!(self, Self::Poisoned { .. }) + } + + fn state_name(&self) -> &'static str { + match self { + Self::Idle => "Idle", + Self::WritingRecord { .. } => "WritingRecord", + Self::Poisoned { .. } => "Poisoned", + } + } +} + +// ============= FakeTlsWriter ============= + +/// Writer that wraps data in TLS 1.3 records with proper state machine +/// +/// This writer handles partial writes correctly by: +/// - Building complete TLS records before writing +/// - Maintaining internal state for partial record writes +/// - Never splitting a record mid-write to upstream +/// +/// # State Machine +/// +/// ┌──────────┐ write ┌─────────────────┐ +/// │ Idle │ -------------> │ WritingRecord │ +/// │ │ <------------- │ │ +/// └──────────┘ complete └─────────────────┘ +/// │ │ +/// │ < errors > │ +/// │ │ +/// ┌─────────────────────────────────────────────┐ +/// │ Poisoned │ +/// └─────────────────────────────────────────────┘ +/// +/// # Record Formation +/// +/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes. +/// Each record has a 5-byte header prepended. pub struct FakeTlsWriter { + /// Upstream writer upstream: W, + /// Current state + state: TlsWriterState, } impl FakeTlsWriter { /// Create new fake TLS writer pub fn new(upstream: W) -> Self { - Self { upstream } + Self { + upstream, + state: TlsWriterState::Idle, + } } /// Get reference to upstream @@ -195,70 +641,535 @@ impl FakeTlsWriter { pub fn into_inner(self) -> W { self.upstream } + + /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { + self.state.is_poisoned() + } + + /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { + self.state.state_name() + } + + /// Check if there's a pending record to write + pub fn has_pending(&self) -> bool { + matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty()) + } + + /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { + self.state = TlsWriterState::Poisoned { error: Some(error) }; + } + + /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { + match &mut self.state { + TlsWriterState::Poisoned { error } => { + error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }) + } + _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), + } + } + + /// Build a TLS Application Data record + fn build_record(data: &[u8]) -> BytesMut { + let header = TlsRecordHeader { + record_type: TLS_RECORD_APPLICATION, + version: TLS_VERSION, + length: data.len() as u16, + }; + + let mut record = BytesMut::with_capacity(TLS_HEADER_SIZE + data.len()); + record.extend_from_slice(&header.to_bytes()); + record.extend_from_slice(data); + record + } +} + +/// Result of flushing pending record +enum FlushResult { + /// All data flushed, returns payload size + Complete(usize), + /// Need to wait for upstream + Pending, + /// Error occurred + Error(io::Error), +} + +impl FakeTlsWriter { + /// Try to flush pending record to upstream (standalone logic) + fn poll_flush_record_inner( + upstream: &mut W, + cx: &mut Context<'_>, + record: &mut WriteBuffer, + ) -> FlushResult { + while !record.is_empty() { + let data = record.pending(); + match Pin::new(&mut *upstream).poll_write(cx, data) { + Poll::Pending => return FlushResult::Pending, + + Poll::Ready(Err(e)) => return FlushResult::Error(e), + + Poll::Ready(Ok(0)) => { + return FlushResult::Error(Error::new( + ErrorKind::WriteZero, + "upstream returned 0 bytes written" + )); + } + + Poll::Ready(Ok(n)) => { + record.advance(n); + } + } + } + + FlushResult::Complete(0) + } } impl AsyncWrite for FakeTlsWriter { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - // Build TLS record - let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE); - let chunk = &buf[..chunk_size]; + let this = self.get_mut(); - let mut record = Vec::with_capacity(5 + chunk_size); - record.push(TLS_RECORD_APPLICATION); - record.extend_from_slice(&TLS_VERSION); - record.push((chunk_size >> 8) as u8); - record.push(chunk_size as u8); - record.extend_from_slice(chunk); + // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - match Pin::new(&mut self.upstream).poll_write(cx, &record) { - Poll::Ready(Ok(written)) => { - if written >= 5 { - Poll::Ready(Ok(written - 5)) - } else { - Poll::Ready(Ok(0)) + match state { + TlsWriterState::Poisoned { error } => { + this.state = TlsWriterState::Poisoned { error: None }; + let err = error.unwrap_or_else(|| { + Error::new(ErrorKind::Other, "stream previously poisoned") + }); + return Poll::Ready(Err(err)); + } + + TlsWriterState::WritingRecord { mut record, payload_size } => { + // Continue flushing existing record + match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { + FlushResult::Pending => { + this.state = TlsWriterState::WritingRecord { record, payload_size }; + return Poll::Pending; + } + FlushResult::Error(e) => { + this.poison(Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + FlushResult::Complete(_) => { + this.state = TlsWriterState::Idle; + // Fall through to handle new write + } } } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, + + TlsWriterState::Idle => { + this.state = TlsWriterState::Idle; + } + } + + // Now in Idle state + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + + // Chunk to maximum TLS payload size + let chunk_size = buf.len().min(MAX_TLS_PAYLOAD); + let chunk = &buf[..chunk_size]; + + // Build the complete record + let record_data = Self::build_record(chunk); + + // Try to write directly first + match Pin::new(&mut this.upstream).poll_write(cx, &record_data) { + Poll::Ready(Ok(n)) if n == record_data.len() => { + // Complete record written + Poll::Ready(Ok(chunk_size)) + } + + Poll::Ready(Ok(n)) => { + // Partial write - buffer the rest + let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); + let _ = write_buffer.extend(&record_data[n..]); + + this.state = TlsWriterState::WritingRecord { + record: write_buffer, + payload_size: chunk_size, + }; + + // We've accepted chunk_size bytes from caller + Poll::Ready(Ok(chunk_size)) + } + + Poll::Ready(Err(e)) => { + this.poison(Error::new(e.kind(), e.to_string())); + Poll::Ready(Err(e)) + } + + Poll::Pending => { + // Buffer the entire record + let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); + let _ = write_buffer.extend(&record_data); + + this.state = TlsWriterState::WritingRecord { + record: write_buffer, + payload_size: chunk_size, + }; + + // Wake to try again + cx.waker().wake_by_ref(); + + // We've accepted chunk_size bytes from caller + Poll::Ready(Ok(chunk_size)) + } } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.upstream).poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); + + match state { + TlsWriterState::Poisoned { error } => { + this.state = TlsWriterState::Poisoned { error: None }; + let err = error.unwrap_or_else(|| { + Error::new(ErrorKind::Other, "stream previously poisoned") + }); + return Poll::Ready(Err(err)); + } + + TlsWriterState::WritingRecord { mut record, payload_size } => { + match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { + FlushResult::Pending => { + this.state = TlsWriterState::WritingRecord { record, payload_size }; + return Poll::Pending; + } + FlushResult::Error(e) => { + this.poison(Error::new(e.kind(), e.to_string())); + return Poll::Ready(Err(e)); + } + FlushResult::Complete(_) => { + this.state = TlsWriterState::Idle; + } + } + } + + TlsWriterState::Idle => { + this.state = TlsWriterState::Idle; + } + } + + // Flush upstream + Pin::new(&mut this.upstream).poll_flush(cx) } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.upstream).poll_shutdown(cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); + + match state { + TlsWriterState::WritingRecord { mut record, payload_size } => { + // Try to flush pending (best effort) + match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { + FlushResult::Pending => { + // Can't complete flush, continue with shutdown anyway + this.state = TlsWriterState::Idle; + } + FlushResult::Error(_) => { + // Ignore errors during shutdown + this.state = TlsWriterState::Idle; + } + FlushResult::Complete(_) => { + this.state = TlsWriterState::Idle; + } + } + } + _ => { + this.state = TlsWriterState::Idle; + } + } + + // Shutdown upstream + Pin::new(&mut this.upstream).poll_shutdown(cx) } } impl FakeTlsWriter { /// Write all data wrapped in TLS records (async method) + /// + /// This convenience method handles chunking large data into + /// multiple TLS records automatically. pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { - for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) { - let header = [ - TLS_RECORD_APPLICATION, - TLS_VERSION[0], - TLS_VERSION[1], - (chunk.len() >> 8) as u8, - chunk.len() as u8, - ]; + let mut written = 0; + while written < data.len() { + let chunk_size = (data.len() - written).min(MAX_TLS_PAYLOAD); + let chunk = &data[written..written + chunk_size]; - self.upstream.write_all(&header).await?; - self.upstream.write_all(chunk).await?; + AsyncWriteExt::write_all(self, chunk).await?; + written += chunk_size; } - Ok(()) + + self.flush().await } } +// ============= Tests ============= + #[cfg(test)] mod tests { use super::*; - use tokio::io::duplex; + use std::collections::VecDeque; + use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; + + // ============= Test Helpers ============= + + /// Build a valid TLS Application Data record + fn build_tls_record(data: &[u8]) -> Vec { + let mut record = vec![ + TLS_RECORD_APPLICATION, + TLS_VERSION[0], + TLS_VERSION[1], + (data.len() >> 8) as u8, + data.len() as u8, + ]; + record.extend_from_slice(data); + record + } + + /// Build a Change Cipher Spec record + fn build_ccs_record() -> Vec { + vec![ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, 0x01, // length = 1 + 0x01, // CCS byte + ] + } + + /// Mock reader that returns data in chunks + struct ChunkedReader { + data: VecDeque, + chunk_size: usize, + } + + impl ChunkedReader { + fn new(data: &[u8], chunk_size: usize) -> Self { + Self { + data: data.iter().copied().collect(), + chunk_size, + } + } + } + + impl AsyncRead for ChunkedReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.data.is_empty() { + return Poll::Ready(Ok(())); + } + + let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining()); + for _ in 0..to_read { + if let Some(byte) = self.data.pop_front() { + buf.put_slice(&[byte]); + } + } + + Poll::Ready(Ok(())) + } + } + + // ============= FakeTlsReader Tests ============= + + #[tokio::test] + async fn test_tls_reader_single_record() { + let payload = b"Hello, TLS!"; + let record = build_tls_record(payload); + + let reader = ChunkedReader::new(&record, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; payload.len()]; + tls_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + } + + #[tokio::test] + async fn test_tls_reader_multiple_records() { + let payload1 = b"First record"; + let payload2 = b"Second record"; + + let mut data = build_tls_record(payload1); + data.extend_from_slice(&build_tls_record(payload2)); + + let reader = ChunkedReader::new(&data, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf1 = vec![0u8; payload1.len()]; + tls_reader.read_exact(&mut buf1).await.unwrap(); + assert_eq!(&buf1, payload1); + + let mut buf2 = vec![0u8; payload2.len()]; + tls_reader.read_exact(&mut buf2).await.unwrap(); + assert_eq!(&buf2, payload2); + } + + #[tokio::test] + async fn test_tls_reader_partial_header() { + // Read header byte by byte + let payload = b"Test"; + let record = build_tls_record(payload); + + let reader = ChunkedReader::new(&record, 1); // 1 byte at a time! + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; payload.len()]; + tls_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + } + + #[tokio::test] + async fn test_tls_reader_partial_body() { + let payload = b"This is a longer payload that will be read in parts"; + let record = build_tls_record(payload); + + let reader = ChunkedReader::new(&record, 7); // Awkward chunk size + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; payload.len()]; + tls_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + } + + #[tokio::test] + async fn test_tls_reader_skip_ccs() { + // CCS record followed by application data + let mut data = build_ccs_record(); + let payload = b"After CCS"; + data.extend_from_slice(&build_tls_record(payload)); + + let reader = ChunkedReader::new(&data, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; payload.len()]; + tls_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + } + + #[tokio::test] + async fn test_tls_reader_multiple_ccs() { + // Multiple CCS records + let mut data = build_ccs_record(); + data.extend_from_slice(&build_ccs_record()); + let payload = b"After multiple CCS"; + data.extend_from_slice(&build_tls_record(payload)); + + let reader = ChunkedReader::new(&data, 3); // Small chunks + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; payload.len()]; + tls_reader.read_exact(&mut buf).await.unwrap(); + + assert_eq!(&buf, payload); + } + + #[tokio::test] + async fn test_tls_reader_eof() { + let reader = ChunkedReader::new(&[], 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; 10]; + let read = tls_reader.read(&mut buf).await.unwrap(); + + assert_eq!(read, 0); + } + + #[tokio::test] + async fn test_tls_reader_state_names() { + let reader = ChunkedReader::new(&[], 100); + let tls_reader = FakeTlsReader::new(reader); + + assert_eq!(tls_reader.state_name(), "Idle"); + assert!(!tls_reader.is_poisoned()); + } + + // ============= FakeTlsWriter Tests ============= + + #[tokio::test] + async fn test_tls_writer_single_write() { + let (client, mut server) = duplex(4096); + let mut writer = FakeTlsWriter::new(client); + + let payload = b"Hello, TLS!"; + writer.write_all(payload).await.unwrap(); + writer.flush().await.unwrap(); + + // Read the TLS record from server + let mut header = [0u8; 5]; + server.read_exact(&mut header).await.unwrap(); + + assert_eq!(header[0], TLS_RECORD_APPLICATION); + assert_eq!(&header[1..3], &TLS_VERSION); + + let length = u16::from_be_bytes([header[3], header[4]]) as usize; + assert_eq!(length, payload.len()); + + let mut body = vec![0u8; length]; + server.read_exact(&mut body).await.unwrap(); + assert_eq!(&body, payload); + } + + #[tokio::test] + async fn test_tls_writer_large_data_chunking() { + let (client, mut server) = duplex(65536); + let mut writer = FakeTlsWriter::new(client); + + // Write data larger than MAX_TLS_PAYLOAD + let payload: Vec = (0..20000).map(|i| (i % 256) as u8).collect(); + writer.write_all(&payload).await.unwrap(); + writer.flush().await.unwrap(); + + // Read back - should be multiple records + let mut received = Vec::new(); + let mut records_count = 0; + + while received.len() < payload.len() { + let mut header = [0u8; 5]; + if server.read_exact(&mut header).await.is_err() { + break; + } + + assert_eq!(header[0], TLS_RECORD_APPLICATION); + records_count += 1; + + let length = u16::from_be_bytes([header[3], header[4]]) as usize; + assert!(length <= MAX_TLS_PAYLOAD); + + let mut body = vec![0u8; length]; + server.read_exact(&mut body).await.unwrap(); + received.extend_from_slice(&body); + } + + assert_eq!(received, payload); + assert!(records_count >= 2); // Should have multiple records + } #[tokio::test] async fn test_tls_stream_roundtrip() { @@ -274,4 +1185,147 @@ mod tests { let received = reader.read_exact(original.len()).await.unwrap(); assert_eq!(&received[..], original); } + + #[tokio::test] + async fn test_tls_stream_roundtrip_large() { + let (client, server) = duplex(4096); + + let mut writer = FakeTlsWriter::new(client); + let mut reader = FakeTlsReader::new(server); + + let original: Vec = (0..50000).map(|i| (i % 256) as u8).collect(); + + // Write in background + let write_data = original.clone(); + let write_handle = tokio::spawn(async move { + writer.write_all_tls(&write_data).await.unwrap(); + writer.shutdown().await.unwrap(); + }); + + // Read + let mut received = Vec::new(); + let mut buf = vec![0u8; 1024]; + loop { + let n = reader.read(&mut buf).await.unwrap(); + if n == 0 { + break; + } + received.extend_from_slice(&buf[..n]); + } + + write_handle.await.unwrap(); + assert_eq!(received, original); + } + + #[tokio::test] + async fn test_tls_writer_state_names() { + let (client, _server) = duplex(4096); + let writer = FakeTlsWriter::new(client); + + assert_eq!(writer.state_name(), "Idle"); + assert!(!writer.is_poisoned()); + assert!(!writer.has_pending()); + } + + // ============= Error Handling Tests ============= + + #[tokio::test] + async fn test_tls_reader_invalid_version() { + let invalid_record = vec![ + TLS_RECORD_APPLICATION, + 0x04, 0x00, // Invalid version + 0x00, 0x05, // length = 5 + 0x01, 0x02, 0x03, 0x04, 0x05, + ]; + + let reader = ChunkedReader::new(&invalid_record, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; 5]; + let result = tls_reader.read(&mut buf).await; + + assert!(result.is_err()); + assert!(tls_reader.is_poisoned()); + } + + #[tokio::test] + async fn test_tls_reader_unexpected_eof_header() { + // Partial header + let partial = vec![TLS_RECORD_APPLICATION, 0x03]; + + let reader = ChunkedReader::new(&partial, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; 10]; + let result = tls_reader.read(&mut buf).await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_tls_reader_unexpected_eof_body() { + // Valid header but truncated body + let mut record = vec![ + TLS_RECORD_APPLICATION, + TLS_VERSION[0], TLS_VERSION[1], + 0x00, 0x10, // length = 16 + ]; + record.extend_from_slice(&[0u8; 8]); // Only 8 bytes of body + + let reader = ChunkedReader::new(&record, 100); + let mut tls_reader = FakeTlsReader::new(reader); + + let mut buf = vec![0u8; 16]; + let result = tls_reader.read(&mut buf).await; + + assert!(result.is_err()); + } + + // ============= Header Parsing Tests ============= + + #[test] + fn test_tls_record_header_parse() { + let header = [0x17, 0x03, 0x03, 0x01, 0x00]; + let parsed = TlsRecordHeader::parse(&header).unwrap(); + + assert_eq!(parsed.record_type, TLS_RECORD_APPLICATION); + assert_eq!(parsed.version, TLS_VERSION); + assert_eq!(parsed.length, 256); + } + + #[test] + fn test_tls_record_header_validate() { + let valid = TlsRecordHeader { + record_type: TLS_RECORD_APPLICATION, + version: TLS_VERSION, + length: 100, + }; + assert!(valid.validate().is_ok()); + + let invalid_version = TlsRecordHeader { + record_type: TLS_RECORD_APPLICATION, + version: [0x04, 0x00], + length: 100, + }; + assert!(invalid_version.validate().is_err()); + + let too_large = TlsRecordHeader { + record_type: TLS_RECORD_APPLICATION, + version: TLS_VERSION, + length: 20000, + }; + assert!(too_large.validate().is_err()); + } + + #[test] + fn test_tls_record_header_to_bytes() { + let header = TlsRecordHeader { + record_type: TLS_RECORD_APPLICATION, + version: TLS_VERSION, + length: 0x1234, + }; + + let bytes = header.to_bytes(); + assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]); + } } \ No newline at end of file