From 27ac32a901389f325810626ba7b18b96c606d971 Mon Sep 17 00:00:00 2001 From: brekotis <2003123e@gmail.com> Date: Mon, 12 Jan 2026 00:26:56 +0300 Subject: [PATCH] Fixes in TLS for iOS --- src/stream/tls_stream.rs | 570 +++++++++++++++++---------------------- 1 file changed, 243 insertions(+), 327 deletions(-) diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index 287fdb8..a4edf58 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -1,17 +1,36 @@ //! Fake TLS 1.3 stream wrappers //! -//! This module provides stateful async stream wrappers that handle -//! TLS record framing with proper partial read/write handling. +//! This module provides stateful async stream wrappers that handle TLS record +//! framing with proper partial read/write handling. //! -//! These are "fake" TLS streams - they wrap data in valid TLS 1.3 -//! Application Data records but don't perform actual TLS encryption. -//! The actual encryption is handled by the crypto layer underneath. +//! These are "fake" TLS streams: +//! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data). +//! - We DO NOT perform real TLS handshake/encryption. +//! - Real crypto for MTProto is handled by the crypto layer underneath. +//! +//! Why do we need this? +//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for +//! domain fronting / traffic camouflage. iOS Telegram clients are known to +//! produce slightly different TLS record sizing patterns than Android/Desktop, +//! including records that exceed 16384 payload bytes by a small overhead. //! //! Key design principles: //! - Explicit state machines for all async operations //! - Never lose data on partial reads //! - Atomic TLS record formation for writes //! - Proper handling of all TLS record types +//! +//! Important nuance (Telegram FakeTLS): +//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes. +//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3 +//! uses AEAD and can include tag/overhead/padding. +//! - Telegram FakeTLS clients (notably iOS) may send Application Data records +//! with length up to 16384 + 24 bytes. We accept that as MAX_TLS_CHUNK_SIZE. +//! +//! If you reject those (e.g. validate length <= 16384), you will see errors like: +//! "TLS record too large: 16408 bytes" +//! and uploads from iOS will break (media/file sending), while small traffic +//! may still work. use bytes::{Bytes, BytesMut, BufMut}; use std::io::{self, Error, ErrorKind, Result}; @@ -20,25 +39,29 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; use crate::protocol::constants::{ - TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, - TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE, + TLS_VERSION, + TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, + MAX_TLS_CHUNK_SIZE, }; use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; // ============= Constants ============= -/// TLS record header size +/// TLS record header size (type + version + length) const TLS_HEADER_SIZE: usize = 5; -/// Maximum TLS record payload size (16KB as per TLS spec) +/// Maximum TLS fragment size per spec (plaintext fragment). +/// We use this for *outgoing* chunking, because we build plain ApplicationData records. const MAX_TLS_PAYLOAD: usize = 16384; -/// Maximum pending write buffer +/// Maximum pending write buffer for one record remainder. +/// Note: we never queue unlimited amount of data here; state holds at most one record. const MAX_PENDING_WRITE: usize = 64 * 1024; // ============= TLS Record Types ============= -/// Parsed TLS record header +/// Parsed TLS record header (5 bytes) #[derive(Debug, Clone, Copy)] struct TlsRecordHeader { /// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.) @@ -50,50 +73,64 @@ struct TlsRecordHeader { } impl TlsRecordHeader { - /// Parse header from 5 bytes + /// Parse header from exactly 5 bytes. + /// + /// This currently never returns None, but is kept as Option to allow future + /// stricter parsing rules without changing callers. fn parse(header: &[u8; 5]) -> Option { let record_type = header[0]; let version = [header[1], header[2]]; let length = u16::from_be_bytes([header[3], header[4]]); - - Some(Self { - record_type, - version, - length, - }) + Some(Self { record_type, version, length }) } - - /// Validate the header + + /// Validate the header. + /// + /// Nuances: + /// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01), + /// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03). + /// - For Application Data, Telegram FakeTLS may send payload length up to + /// MAX_TLS_CHUNK_SIZE (16384 + 24). + /// - For other record types we keep stricter bounds to avoid memory abuse. fn validate(&self) -> Result<()> { - // Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others) + // Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303). if self.version != [0x03, 0x01] && self.version != TLS_VERSION { return Err(Error::new( ErrorKind::InvalidData, format!("Invalid TLS version: {:02x?}", self.version), )); } - - // Check length - if self.length as usize > MAX_TLS_RECORD_SIZE { - return Err(Error::new( - ErrorKind::InvalidData, - format!("TLS record too large: {} bytes", self.length), - )); + + let len = self.length as usize; + + // Length checks depend on record type. + // Telegram FakeTLS: ApplicationData length may be 16384 + 24. + match self.record_type { + TLS_RECORD_APPLICATION => { + if len > MAX_TLS_CHUNK_SIZE { + return Err(Error::new( + ErrorKind::InvalidData, + format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE), + )); + } + } + + // ChangeCipherSpec/Alert/Handshake should never be that large for our usage + // (post-handshake we don't expect Handshake at all). + // Keep strict to reduce attack surface. + _ => { + if len > MAX_TLS_PAYLOAD { + return Err(Error::new( + ErrorKind::InvalidData, + format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD), + )); + } + } } - + Ok(()) } - - /// Check if this is an application data record - fn is_application_data(&self) -> bool { - self.record_type == TLS_RECORD_APPLICATION - } - - /// Check if this is a change cipher spec record (should be skipped) - fn is_change_cipher_spec(&self) -> bool { - self.record_type == TLS_RECORD_CHANGE_CIPHER - } - + /// Build header bytes fn to_bytes(&self) -> [u8; 5] { [ @@ -113,32 +150,27 @@ impl TlsRecordHeader { enum TlsReaderState { /// Ready to read a new TLS record Idle, - + /// Reading the 5-byte TLS record header ReadingHeader { /// Header buffer (5 bytes) header: HeaderBuffer, }, - - /// Reading the TLS record body + + /// Reading the TLS record body (payload) ReadingBody { - /// Parsed record type record_type: u8, - /// Total body length length: usize, - /// Buffer for body data buffer: BytesMut, }, - - /// Have decrypted data ready to yield to caller + + /// Have buffered data ready to yield to caller Yielding { - /// Buffer containing data to yield buffer: YieldBuffer, }, - + /// Stream encountered an error and cannot be used Poisoned { - /// The error that caused poisoning error: Option, }, } @@ -147,11 +179,11 @@ impl StreamState for TlsReaderState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -165,12 +197,13 @@ impl StreamState for TlsReaderState { // ============= FakeTlsReader ============= -/// Reader that unwraps TLS 1.3 records with proper state machine +/// Reader that unwraps TLS records (FakeTLS). /// -/// This reader handles partial reads correctly by maintaining internal state -/// and never losing any data that has been read from upstream. +/// This wrapper is responsible ONLY for TLS record framing and skipping +/// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed +/// as-is to upper layers (crypto stream). /// -/// # State Machine +/// State machine overview: /// /// ┌──────────┐ ┌───────────────┐ /// │ Idle │ -----------------> │ ReadingHeader │ @@ -178,103 +211,69 @@ impl StreamState for TlsReaderState { /// ▲ │ /// │ header complete /// │ │ -/// │ │ +/// │ ▼ /// │ ┌───────────────┐ /// │ skip record │ ReadingBody │ /// │ <-------- (CCS) -------- │ │ /// │ └───────┬───────┘ /// │ │ /// │ body complete -/// │ drained │ -/// │ <-----------------┐ │ -/// │ │ ┌───────────────┐ -/// │ └----- │ Yielding │ +/// │ ▼ +/// │ ┌───────────────┐ +/// │ │ Yielding │ /// │ └───────────────┘ /// │ -/// │ errors /w any state -/// │ +/// │ errors / w any state +/// ▼ /// ┌───────────────────────────────────────────────┐ /// │ Poisoned │ /// └───────────────────────────────────────────────┘ /// +/// NOTE: We must correctly handle partial reads from upstream: +/// - do not assume header arrives in one poll +/// - do not assume body arrives in one poll +/// - never lose already-read bytes pub struct FakeTlsReader { - /// Upstream reader upstream: R, - /// Current state state: TlsReaderState, } impl FakeTlsReader { - /// Create new fake TLS reader pub fn new(upstream: R) -> Self { - Self { - upstream, - state: TlsReaderState::Idle, - } + Self { upstream, state: TlsReaderState::Idle } } - - /// Get reference to upstream - pub fn get_ref(&self) -> &R { - &self.upstream - } - - /// Get mutable reference to upstream - pub fn get_mut(&mut self) -> &mut R { - &mut self.upstream - } - - /// Consume and return upstream - pub fn into_inner(self) -> R { - self.upstream - } - - /// Check if stream is in poisoned state - pub fn is_poisoned(&self) -> bool { - self.state.is_poisoned() - } - - /// Get current state name (for debugging) - pub fn state_name(&self) -> &'static str { - self.state.state_name() - } - - /// Transition to poisoned state + + pub fn get_ref(&self) -> &R { &self.upstream } + pub fn get_mut(&mut self) -> &mut R { &mut self.upstream } + pub fn into_inner(self) -> R { self.upstream } + + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } + pub fn state_name(&self) -> &'static str { self.state.state_name() } + fn poison(&mut self, error: io::Error) { self.state = TlsReaderState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - TlsReaderState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } } -/// Result of polling for header completion enum HeaderPollResult { - /// Need more data Pending, - /// EOF at record boundary (clean close) Eof, - /// Header complete, parsed successfully Complete(TlsRecordHeader), - /// Error occurred Error(io::Error), } -/// Result of polling for body completion enum BodyPollResult { - /// Need more data Pending, - /// Body complete Complete(Bytes), - /// Error occurred Error(io::Error), } @@ -285,13 +284,13 @@ impl AsyncRead for FakeTlsReader { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - + loop { // Take ownership of state to avoid borrow conflicts let state = std::mem::replace(&mut this.state, TlsReaderState::Idle); - + match state { - // Poisoned state - return error + // Poisoned state: always return the stored error TlsReaderState::Poisoned { error } => { this.state = TlsReaderState::Poisoned { error: None }; let err = error.unwrap_or_else(|| { @@ -299,55 +298,52 @@ impl AsyncRead for FakeTlsReader { }); return Poll::Ready(Err(err)); } - - // Have buffered data to yield + + // Yield buffered plaintext to caller TlsReaderState::Yielding { mut buffer } => { if buf.remaining() == 0 { this.state = TlsReaderState::Yielding { buffer }; return Poll::Ready(Ok(())); } - - // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); let dst = buf.initialize_unfilled_to(to_copy); let copied = buffer.copy_to(dst); buf.advance(copied); - - // If buffer is drained, transition to Idle + if buffer.is_empty() { this.state = TlsReaderState::Idle; } else { this.state = TlsReaderState::Yielding { buffer }; } - + return Poll::Ready(Ok(())); } - - // Ready to read a new TLS record + + // Start reading new record TlsReaderState::Idle => { if buf.remaining() == 0 { this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } - - // Start reading header + this.state = TlsReaderState::ReadingHeader { header: HeaderBuffer::new(), }; - // Continue to ReadingHeader + // loop continues and will handle ReadingHeader } - - // Reading TLS record header + + // Read TLS header (5 bytes) TlsReaderState::ReadingHeader { mut header } => { - // Poll to fill header let result = poll_read_header(&mut this.upstream, cx, &mut header); - + match result { HeaderPollResult::Pending => { this.state = TlsReaderState::ReadingHeader { header }; return Poll::Pending; } HeaderPollResult::Eof => { + // Clean EOF at record boundary this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } @@ -356,15 +352,12 @@ impl AsyncRead for FakeTlsReader { return Poll::Ready(Err(e)); } HeaderPollResult::Complete(parsed) => { - // Validate header if let Err(e) = parsed.validate() { this.poison(Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + let length = parsed.length as usize; - - // Transition to reading body this.state = TlsReaderState::ReadingBody { record_type: parsed.record_type, length, @@ -373,11 +366,11 @@ impl AsyncRead for FakeTlsReader { } } } - - // Reading TLS record body + + // Read TLS payload TlsReaderState::ReadingBody { record_type, length, mut buffer } => { let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); - + match result { BodyPollResult::Pending => { this.state = TlsReaderState::ReadingBody { record_type, length, buffer }; @@ -388,42 +381,43 @@ impl AsyncRead for FakeTlsReader { return Poll::Ready(Err(e)); } BodyPollResult::Complete(data) => { - // Handle different record types match record_type { TLS_RECORD_CHANGE_CIPHER => { - // Skip Change Cipher Spec, read next record + // CCS is expected in some clients, ignore it. this.state = TlsReaderState::Idle; continue; } + TLS_RECORD_APPLICATION => { - // Application data - yield to caller + // This is what we actually want. if data.is_empty() { this.state = TlsReaderState::Idle; continue; } - + this.state = TlsReaderState::Yielding { buffer: YieldBuffer::new(data), }; - // Continue to yield + // loop continues and will yield immediately } + TLS_RECORD_ALERT => { - // TLS Alert - treat as EOF + // Treat TLS alert as EOF-like termination. this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } + TLS_RECORD_HANDSHAKE => { - let err = Error::new( - ErrorKind::InvalidData, - "unexpected TLS handshake record" - ); + // After FakeTLS handshake is done, we do not expect any Handshake records. + let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record"); this.poison(Error::new(err.kind(), err.to_string())); return Poll::Ready(Err(err)); } + _ => { let err = Error::new( ErrorKind::InvalidData, - format!("unknown TLS record type: 0x{:02x}", record_type) + format!("unknown TLS record type: 0x{:02x}", record_type), ); this.poison(Error::new(err.kind(), err.to_string())); return Poll::Ready(Err(err)); @@ -446,7 +440,7 @@ fn poll_read_header( while !header.is_complete() { let unfilled = header.unfilled_mut(); let mut read_buf = ReadBuf::new(unfilled); - + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { Poll::Pending => return HeaderPollResult::Pending, Poll::Ready(Err(e)) => return HeaderPollResult::Error(e), @@ -459,8 +453,10 @@ fn poll_read_header( } else { return HeaderPollResult::Error(Error::new( ErrorKind::UnexpectedEof, - format!("unexpected EOF in TLS header (got {} of 5 bytes)", - header.as_slice().len()) + format!( + "unexpected EOF in TLS header (got {} of 5 bytes)", + header.as_slice().len() + ), )); } } @@ -468,15 +464,11 @@ fn poll_read_header( } } } - - // Parse header + let header_bytes = *header.as_array(); match TlsRecordHeader::parse(&header_bytes) { Some(h) => HeaderPollResult::Complete(h), - None => HeaderPollResult::Error(Error::new( - ErrorKind::InvalidData, - "failed to parse TLS header" - )), + None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")), } } @@ -487,13 +479,15 @@ fn poll_read_body( buffer: &mut BytesMut, target_len: usize, ) -> BodyPollResult { + // NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime + // issues with BytesMut spare capacity and ReadBuf across polls. + // It's safe and correct; optimization is possible if needed. while buffer.len() < target_len { let remaining = target_len - buffer.len(); - - // Read into a temporary buffer + let mut temp = vec![0u8; remaining.min(8192)]; let mut read_buf = ReadBuf::new(&mut temp); - + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { Poll::Pending => return BodyPollResult::Pending, Poll::Ready(Err(e)) => return BodyPollResult::Error(e), @@ -502,67 +496,65 @@ fn poll_read_body( if n == 0 { return BodyPollResult::Error(Error::new( ErrorKind::UnexpectedEof, - format!("unexpected EOF in TLS body (got {} of {} bytes)", - buffer.len(), target_len) + format!( + "unexpected EOF in TLS body (got {} of {} bytes)", + buffer.len(), + target_len + ), )); } buffer.extend_from_slice(&temp[..n]); } } } - + BodyPollResult::Complete(buffer.split().freeze()) } impl FakeTlsReader { - /// Read exactly n bytes through TLS layer + /// Read exactly n bytes through TLS layer. /// - /// This is a convenience method that accumulates data across - /// multiple TLS records until exactly n bytes are available. + /// This accumulates data across multiple TLS ApplicationData records. pub async fn read_exact(&mut self, n: usize) -> Result { if self.is_poisoned() { return Err(self.take_poison_error()); } - + let mut result = BytesMut::with_capacity(n); - + while result.len() < n { let mut buf = vec![0u8; n - result.len()]; let read = AsyncReadExt::read(self, &mut buf).await?; - + if read == 0 { return Err(Error::new( ErrorKind::UnexpectedEof, - format!("expected {} bytes, got {}", n, result.len()) + format!("expected {} bytes, got {}", n, result.len()), )); } - + result.extend_from_slice(&buf[..read]); } - + Ok(result.freeze()) } } // ============= FakeTlsWriter State ============= -/// State machine states for FakeTlsWriter #[derive(Debug)] enum TlsWriterState { /// Ready to accept new data Idle, - - /// Writing a complete TLS record + + /// Writing a complete TLS record (header + body), possibly partially WritingRecord { - /// Complete record (header + body) to write record: WriteBuffer, - /// Original payload size (for return value calculation) payload_size: usize, }, - + /// Stream encountered an error and cannot be used Poisoned { - /// The error that caused poisoning error: Option, }, } @@ -571,11 +563,11 @@ impl StreamState for TlsWriterState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -587,101 +579,53 @@ impl StreamState for TlsWriterState { // ============= FakeTlsWriter ============= -/// Writer that wraps data in TLS 1.3 records with proper state machine +/// Writer that wraps bytes into TLS 1.3 Application Data records. /// -/// This writer handles partial writes correctly by: -/// - Building complete TLS records before writing -/// - Maintaining internal state for partial record writes -/// - Never splitting a record mid-write to upstream -/// -/// # State Machine -/// -/// ┌──────────┐ write ┌─────────────────┐ -/// │ Idle │ -------------> │ WritingRecord │ -/// │ │ <------------- │ │ -/// └──────────┘ complete └─────────────────┘ -/// │ │ -/// │ < errors > │ -/// │ │ -/// ┌─────────────────────────────────────────────┐ -/// │ Poisoned │ -/// └─────────────────────────────────────────────┘ -/// -/// # Record Formation -/// -/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes. -/// Each record has a 5-byte header prepended. +/// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD). +/// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it. +/// If you want to be more camouflage-accurate later, you could add optional padding +/// to produce records sized closer to MAX_TLS_CHUNK_SIZE. pub struct FakeTlsWriter { - /// Upstream writer upstream: W, - /// Current state state: TlsWriterState, } impl FakeTlsWriter { - /// Create new fake TLS writer pub fn new(upstream: W) -> Self { - Self { - upstream, - state: TlsWriterState::Idle, - } + Self { upstream, state: TlsWriterState::Idle } } - - /// Get reference to upstream - pub fn get_ref(&self) -> &W { - &self.upstream - } - - /// Get mutable reference to upstream - pub fn get_mut(&mut self) -> &mut W { - &mut self.upstream - } - - /// Consume and return upstream - pub fn into_inner(self) -> W { - self.upstream - } - - /// Check if stream is in poisoned state - pub fn is_poisoned(&self) -> bool { - self.state.is_poisoned() - } - - /// Get current state name (for debugging) - pub fn state_name(&self) -> &'static str { - self.state.state_name() - } - - /// Check if there's a pending record to write + + pub fn get_ref(&self) -> &W { &self.upstream } + pub fn get_mut(&mut self) -> &mut W { &mut self.upstream } + pub fn into_inner(self) -> W { self.upstream } + + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } + pub fn state_name(&self) -> &'static str { self.state.state_name() } + pub fn has_pending(&self) -> bool { matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty()) } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = TlsWriterState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - TlsWriterState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } - - /// Build a TLS Application Data record + fn build_record(data: &[u8]) -> BytesMut { let header = TlsRecordHeader { record_type: TLS_RECORD_APPLICATION, version: TLS_VERSION, length: data.len() as u16, }; - + let mut record = BytesMut::with_capacity(TLS_HEADER_SIZE + data.len()); record.extend_from_slice(&header.to_bytes()); record.extend_from_slice(data); @@ -689,18 +633,13 @@ impl FakeTlsWriter { } } -/// Result of flushing pending record enum FlushResult { - /// All data flushed, returns payload size Complete(usize), - /// Need to wait for upstream Pending, - /// Error occurred Error(io::Error), } impl FakeTlsWriter { - /// Try to flush pending record to upstream (standalone logic) fn poll_flush_record_inner( upstream: &mut W, cx: &mut Context<'_>, @@ -710,22 +649,17 @@ impl FakeTlsWriter { let data = record.pending(); match Pin::new(&mut *upstream).poll_write(cx, data) { Poll::Pending => return FlushResult::Pending, - Poll::Ready(Err(e)) => return FlushResult::Error(e), - Poll::Ready(Ok(0)) => { return FlushResult::Error(Error::new( ErrorKind::WriteZero, - "upstream returned 0 bytes written" + "upstream returned 0 bytes written", )); } - - Poll::Ready(Ok(n)) => { - record.advance(n); - } + Poll::Ready(Ok(n)) => record.advance(n), } } - + FlushResult::Complete(0) } } @@ -737,10 +671,10 @@ impl AsyncWrite for FakeTlsWriter { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + + // Take ownership of state to avoid borrow conflicts. let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - + match state { TlsWriterState::Poisoned { error } => { this.state = TlsWriterState::Poisoned { error: None }; @@ -749,9 +683,9 @@ impl AsyncWrite for FakeTlsWriter { }); return Poll::Ready(Err(err)); } - + TlsWriterState::WritingRecord { mut record, payload_size } => { - // Continue flushing existing record + // Finish writing previous record before accepting new bytes. match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { FlushResult::Pending => { this.state = TlsWriterState::WritingRecord { record, payload_size }; @@ -763,79 +697,76 @@ impl AsyncWrite for FakeTlsWriter { } FlushResult::Complete(_) => { this.state = TlsWriterState::Idle; - // Fall through to handle new write + // continue to accept new buf below } } } - + TlsWriterState::Idle => { this.state = TlsWriterState::Idle; } } - + // Now in Idle state if buf.is_empty() { return Poll::Ready(Ok(0)); } - + // Chunk to maximum TLS payload size let chunk_size = buf.len().min(MAX_TLS_PAYLOAD); let chunk = &buf[..chunk_size]; - - // Build the complete record + + // Build the complete record (header + payload) let record_data = Self::build_record(chunk); - - // Try to write directly first + match Pin::new(&mut this.upstream).poll_write(cx, &record_data) { Poll::Ready(Ok(n)) if n == record_data.len() => { - // Complete record written Poll::Ready(Ok(chunk_size)) } - + Poll::Ready(Ok(n)) => { - // Partial write - buffer the rest + // Partial write of the record: store remainder. let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); + // record_data length is <= 16389, fits MAX_PENDING_WRITE let _ = write_buffer.extend(&record_data[n..]); - + this.state = TlsWriterState::WritingRecord { record: write_buffer, payload_size: chunk_size, }; - - // We've accepted chunk_size bytes from caller + + // We have accepted chunk_size bytes from caller. Poll::Ready(Ok(chunk_size)) } - + Poll::Ready(Err(e)) => { this.poison(Error::new(e.kind(), e.to_string())); Poll::Ready(Err(e)) } - + Poll::Pending => { - // Buffer the entire record + // Buffer entire record and report success for this chunk. let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let _ = write_buffer.extend(&record_data); - + this.state = TlsWriterState::WritingRecord { record: write_buffer, payload_size: chunk_size, }; - - // Wake to try again + + // Wake to retry flushing soon. cx.waker().wake_by_ref(); - - // We've accepted chunk_size bytes from caller + Poll::Ready(Ok(chunk_size)) } } } - + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - + match state { TlsWriterState::Poisoned { error } => { this.state = TlsWriterState::Poisoned { error: None }; @@ -844,7 +775,7 @@ impl AsyncWrite for FakeTlsWriter { }); return Poll::Ready(Err(err)); } - + TlsWriterState::WritingRecord { mut record, payload_size } => { match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { FlushResult::Pending => { @@ -860,64 +791,49 @@ impl AsyncWrite for FakeTlsWriter { } } } - + TlsWriterState::Idle => { this.state = TlsWriterState::Idle; } } - - // Flush upstream + Pin::new(&mut this.upstream).poll_flush(cx) } - + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - + match state { - TlsWriterState::WritingRecord { mut record, payload_size } => { - // Try to flush pending (best effort) - match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { - FlushResult::Pending => { - // Can't complete flush, continue with shutdown anyway - this.state = TlsWriterState::Idle; - } - FlushResult::Error(_) => { - // Ignore errors during shutdown - this.state = TlsWriterState::Idle; - } - FlushResult::Complete(_) => { - this.state = TlsWriterState::Idle; - } - } + TlsWriterState::WritingRecord { mut record, payload_size: _ } => { + // Best-effort flush (do not block shutdown forever). + let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record); + this.state = TlsWriterState::Idle; } _ => { this.state = TlsWriterState::Idle; } } - - // Shutdown upstream + Pin::new(&mut this.upstream).poll_shutdown(cx) } } impl FakeTlsWriter { - /// Write all data wrapped in TLS records (async method) + /// Write all data wrapped in TLS records. /// - /// This convenience method handles chunking large data into - /// multiple TLS records automatically. + /// Convenience method that chunks into <= 16384 records. pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { let mut written = 0; while written < data.len() { let chunk_size = (data.len() - written).min(MAX_TLS_PAYLOAD); let chunk = &data[written..written + chunk_size]; - + AsyncWriteExt::write_all(self, chunk).await?; written += chunk_size; } - + self.flush().await } }