diff --git a/src/config/mod.rs b/src/config/mod.rs index d8a4806..bbe3f61 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -163,9 +163,12 @@ fn default_port() -> u16 { 443 } fn default_tls_domain() -> String { "www.google.com".to_string() } fn default_mask_port() -> u16 { 443 } fn default_replay_check_len() -> usize { 65536 } -fn default_handshake_timeout() -> u64 { 10 } +// CHANGED: Increased handshake timeout for bad mobile networks +fn default_handshake_timeout() -> u64 { 15 } fn default_connect_timeout() -> u64 { 10 } -fn default_keepalive() -> u64 { 600 } +// CHANGED: Reduced keepalive from 600s to 60s. +// Mobile NATs often drop idle connections after 60-120s. +fn default_keepalive() -> u64 { 60 } fn default_ack_timeout() -> u64 { 300 } fn default_listen_addr() -> String { "0.0.0.0".to_string() } fn default_fake_cert_len() -> usize { 2048 } diff --git a/src/main.rs b/src/main.rs index 45bd872..fde2a59 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,7 @@ mod util; use crate::config::ProxyConfig; use crate::proxy::ClientHandler; -use crate::stats::Stats; +use crate::stats::{Stats, ReplayChecker}; use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::util::ip::detect_ip; @@ -55,6 +55,9 @@ async fn main() -> Result<(), Box> { let config = Arc::new(config); let stats = Arc::new(Stats::new()); + // CHANGED: Initialize global ReplayChecker here instead of per-connection + let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); + // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); @@ -145,13 +148,11 @@ async fn main() -> Result<(), Box> { } // Accept loop - // For simplicity in this slice, we just spawn a task for each listener - // In a real high-perf scenario, we might want a more complex accept loop - for listener in listeners { let config = config.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); + let 
replay_checker = replay_checker.clone(); tokio::spawn(async move { loop { @@ -160,6 +161,7 @@ async fn main() -> Result<(), Box> { let config = config.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); tokio::spawn(async move { if let Err(e) = ClientHandler::new( @@ -167,7 +169,8 @@ async fn main() -> Result<(), Box> { peer_addr, config, stats, - upstream_manager + upstream_manager, + replay_checker // Pass global checker ).run().await { // Log only relevant errors // debug!("Connection error: {}", e); diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 17857e4..d09473e 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300; // ============= Buffer Sizes ============= /// Default buffer size -pub const DEFAULT_BUFFER_SIZE: usize = 65536; +/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with +/// the new buffering strategy for better iOS upload performance. +pub const DEFAULT_BUFFER_SIZE: usize = 16384; + /// Small buffer size for bad client handling pub const SMALL_BUFFER_SIZE: usize = 8192; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 2cb4d9d..11e3a81 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -45,11 +45,10 @@ impl ClientHandler { config: Arc, stats: Arc, upstream_manager: Arc, + replay_checker: Arc, // CHANGED: Accept global checker ) -> RunningClientHandler { - // Note: ReplayChecker should be shared globally for proper replay protection - // Creating it per-connection disables replay protection across connections - // TODO: Pass Arc from main.rs - let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); + // CHANGED: Removed local creation of ReplayChecker. + // It is now passed from main.rs to ensure global replay protection. 
RunningClientHandler { stream, diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index c531a1a..f90b247 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -1,13 +1,21 @@ //! Bidirectional Relay use std::sync::Arc; +use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; -use tracing::{debug, trace, warn}; +use tokio::time::Instant; +use tracing::{debug, trace, warn, info}; use crate::error::Result; use crate::stats::Stats; use std::sync::atomic::{AtomicU64, Ordering}; -const BUFFER_SIZE: usize = 65536; +// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat. +// This is critical for iOS clients to maintain proper TCP flow control during uploads. +const BUFFER_SIZE: usize = 16384; + +// Activity timeout for iOS compatibility (30 minutes) +// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout +const ACTIVITY_TIMEOUT_SECS: u64 = 1800; /// Relay data bidirectionally between client and server pub async fn relay_bidirectional( @@ -36,15 +44,40 @@ where let c2s_bytes_clone = Arc::clone(&c2s_bytes); let s2c_bytes_clone = Arc::clone(&s2c_bytes); - // Client -> Server task + // Activity timeout for iOS compatibility + let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS); + + // Client -> Server task with activity timeout let c2s = tokio::spawn(async move { let mut buf = vec![0u8; BUFFER_SIZE]; let mut total_bytes = 0u64; let mut msg_count = 0u64; + let mut last_activity = Instant::now(); + let mut last_log = Instant::now(); loop { - match client_reader.read(&mut buf).await { - Ok(0) => { + // Read with timeout to prevent infinite hang on iOS + let read_result = tokio::time::timeout( + activity_timeout, + client_reader.read(&mut buf) + ).await; + + match read_result { + // Timeout - no activity for too long + Err(_) => { + warn!( + user = %user_c2s, + total_bytes = total_bytes, + msgs = msg_count, + idle_secs = last_activity.elapsed().as_secs(), + 
"Activity timeout (C->S) - no data received" + ); + let _ = server_writer.shutdown().await; + break; + } + + // Read successful + Ok(Ok(0)) => { debug!( user = %user_c2s, total_bytes = total_bytes, @@ -54,9 +87,11 @@ where let _ = server_writer.shutdown().await; break; } - Ok(n) => { + + Ok(Ok(n)) => { total_bytes += n as u64; msg_count += 1; + last_activity = Instant::now(); c2s_bytes_clone.store(total_bytes, Ordering::Relaxed); stats_c2s.add_user_octets_from(&user_c2s, n as u64); @@ -70,6 +105,19 @@ where "C->S data" ); + // Log activity every 10 seconds for large transfers + if last_log.elapsed() > Duration::from_secs(10) { + let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); + info!( + user = %user_c2s, + total_bytes = total_bytes, + msgs = msg_count, + rate_kbps = (rate / 1024.0) as u64, + "C->S transfer in progress" + ); + last_log = Instant::now(); + } + if let Err(e) = server_writer.write_all(&buf[..n]).await { debug!(user = %user_c2s, error = %e, "Failed to write to server"); break; @@ -79,7 +127,8 @@ where break; } } - Err(e) => { + + Ok(Err(e)) => { debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error"); break; } @@ -87,15 +136,37 @@ where } }); - // Server -> Client task + // Server -> Client task with activity timeout let s2c = tokio::spawn(async move { let mut buf = vec![0u8; BUFFER_SIZE]; let mut total_bytes = 0u64; let mut msg_count = 0u64; + let mut last_activity = Instant::now(); + let mut last_log = Instant::now(); loop { - match server_reader.read(&mut buf).await { - Ok(0) => { + // Read with timeout to prevent infinite hang on iOS + let read_result = tokio::time::timeout( + activity_timeout, + server_reader.read(&mut buf) + ).await; + + match read_result { + // Timeout - no activity for too long + Err(_) => { + warn!( + user = %user_s2c, + total_bytes = total_bytes, + msgs = msg_count, + idle_secs = last_activity.elapsed().as_secs(), + "Activity timeout (S->C) - no data received" + ); + let _ = 
client_writer.shutdown().await; + break; + } + + // Read successful + Ok(Ok(0)) => { debug!( user = %user_s2c, total_bytes = total_bytes, @@ -105,9 +176,11 @@ where let _ = client_writer.shutdown().await; break; } - Ok(n) => { + + Ok(Ok(n)) => { total_bytes += n as u64; msg_count += 1; + last_activity = Instant::now(); s2c_bytes_clone.store(total_bytes, Ordering::Relaxed); stats_s2c.add_user_octets_to(&user_s2c, n as u64); @@ -121,6 +194,19 @@ where "S->C data" ); + // Log activity every 10 seconds for large transfers + if last_log.elapsed() > Duration::from_secs(10) { + let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); + info!( + user = %user_s2c, + total_bytes = total_bytes, + msgs = msg_count, + rate_kbps = (rate / 1024.0) as u64, + "S->C transfer in progress" + ); + last_log = Instant::now(); + } + if let Err(e) = client_writer.write_all(&buf[..n]).await { debug!(user = %user_s2c, error = %e, "Failed to write to client"); break; @@ -130,7 +216,8 @@ where break; } } - Err(e) => { + + Ok(Err(e)) => { debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error"); break; } diff --git a/src/stream/buffer_pool.rs b/src/stream/buffer_pool.rs index 55be736..ac4a5f9 100644 --- a/src/stream/buffer_pool.rs +++ b/src/stream/buffer_pool.rs @@ -11,8 +11,9 @@ use std::sync::Arc; // ============= Configuration ============= -/// Default buffer size (64KB - good for MTProto) -pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024; +/// Default buffer size +/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat. +pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024; /// Default maximum number of pooled buffers pub const DEFAULT_MAX_BUFFERS: usize = 1024; diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index 5ee93a7..b06bc7c 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -6,57 +6,87 @@ //! Key design principles: //! 
- Explicit state machines for all async operations //! - Never lose data on partial reads/writes -//! - Honest reporting of bytes written +//! - Honest reporting of bytes written (AsyncWrite contract) //! - Bounded internal buffers with backpressure +//! +//! AES-CTR is a stream cipher: the keystream position must advance exactly by the +//! number of plaintext bytes that are *accepted* (written or buffered). +//! +//! This implementation guarantees: +//! - CTR state never "drifts" +//! - never accept plaintext unless we can guarantee that all corresponding ciphertext +//! is either written to upstream or stored in our pending buffer +//! - when upstream is pending -> ciphertext is buffered/bounded and backpressure is applied +//! +//! ======================= +//! Writer state machine +//! ======================= +//! +//! ┌──────────┐ write buf ┌──────────┐ +//! │ Idle │ ---------------> │ Flushing │ +//! │ │ <--------------- │ │ +//! └──────────┘ drained └──────────┘ +//! │ │ +//! │ errors │ +//! ▼ ▼ +//! ┌────────────────────────────────────────┐ +//! │ Poisoned │ +//! └────────────────────────────────────────┘ +//! +//! Backpressure +//! - pending ciphertext buffer is bounded (MAX_PENDING_WRITE) +//! - pending is full and upstream is pending +//! -> poll_write returns Poll::Pending +//! -> do not accept any plaintext +//! +//! Performance +//! - fast path when pending is empty: encrypt into scratch and try upstream +//! - if upstream Pending/partial => move remainder into pending without re-encrypting +//! - when upstream is Pending but pending still has room: accept `to_accept` bytes and +//! 
encrypt+append ciphertext directly into pending (in-place encryption of appended range) -use bytes::{Bytes, BytesMut, BufMut}; -use std::io::{self, Error, ErrorKind, Result}; +use bytes::{Bytes, BytesMut}; +use std::io::{self, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tracing::{debug, trace, warn}; use crate::crypto::AesCtr; -use crate::error::StreamError; -use super::state::{StreamState, ReadBuffer, WriteBuffer, YieldBuffer}; +use super::state::{StreamState, YieldBuffer}; // ============= Constants ============= -/// Maximum size for pending write buffer (256KB) -const MAX_PENDING_WRITE: usize = 256 * 1024; +/// Maximum size for pending ciphertext buffer (bounded backpressure). +/// 512 KiB tends to work well for mobile networks and avoids huge latency spikes. +const MAX_PENDING_WRITE: usize = 524_288; -/// Default read buffer capacity +/// Default read buffer capacity (reader mostly decrypts in-place into caller buffer). const DEFAULT_READ_CAPACITY: usize = 16 * 1024; // ============= CryptoReader State ============= -/// State machine states for CryptoReader #[derive(Debug)] enum CryptoReaderState { /// Ready to read new data Idle, - + /// Have decrypted data ready to yield to caller - Yielding { - /// Buffer containing decrypted data - buffer: YieldBuffer, - }, - + Yielding { buffer: YieldBuffer }, + /// Stream encountered an error and cannot be used - Poisoned { - /// The error that caused poisoning (taken on first access) - error: Option, - }, + Poisoned { error: Option }, } impl StreamState for CryptoReaderState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. 
}) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -68,7 +98,7 @@ impl StreamState for CryptoReaderState { // ============= CryptoReader ============= -/// Reader that decrypts data using AES-CTR with proper state machine +/// Reader that decrypts data using AES-CTR with proper state machine. /// /// This reader handles partial reads correctly by maintaining internal state /// and never losing any data that has been read from upstream. @@ -78,25 +108,24 @@ impl StreamState for CryptoReaderState { /// ┌──────────┐ read ┌──────────┐ /// │ Idle │ ------------> │ Yielding │ /// │ │ <------------ │ │ -/// └──────────┘ drained └──────────┘ +/// └──────────┘ drained └──────────┘ /// │ │ /// │ errors │ +/// ▼ ▼ /// ┌──────────────────────────────────────┐ /// │ Poisoned │ /// └──────────────────────────────────────┘ pub struct CryptoReader { - /// Upstream reader upstream: R, - /// AES-CTR decryptor decryptor: AesCtr, - /// Current state state: CryptoReaderState, - /// Internal read buffer for upstream reads + + /// Reserved for future coalescing optimizations. 
+ #[allow(dead_code)] read_buf: BytesMut, } impl CryptoReader { - /// Create new crypto reader pub fn new(upstream: R, decryptor: AesCtr) -> Self { Self { upstream, @@ -105,45 +134,36 @@ impl CryptoReader { read_buf: BytesMut::with_capacity(DEFAULT_READ_CAPACITY), } } - - /// Get reference to upstream + pub fn get_ref(&self) -> &R { &self.upstream } - - /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut R { &mut self.upstream } - - /// Consume and return upstream + pub fn into_inner(self) -> R { self.upstream } - - /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } - - /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { self.state.state_name() } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = CryptoReaderState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoReaderState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } @@ -156,67 +176,61 @@ impl AsyncRead for CryptoReader { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - + loop { match &mut this.state { - // Poisoned state - return error CryptoReaderState::Poisoned { .. 
} => { let err = this.take_poison_error(); return Poll::Ready(Err(err)); } - - // Have buffered data to yield + CryptoReaderState::Yielding { buffer } => { if buf.remaining() == 0 { return Poll::Ready(Ok(())); } - - // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); let dst = buf.initialize_unfilled_to(to_copy); let copied = buffer.copy_to(dst); buf.advance(copied); - - // If buffer is drained, transition to Idle + if buffer.is_empty() { this.state = CryptoReaderState::Idle; } - + return Poll::Ready(Ok(())); } - - // Ready to read from upstream + CryptoReaderState::Idle => { - // If caller's buffer is empty, nothing to do if buf.remaining() == 0 { return Poll::Ready(Ok(())); } - - // Try to read directly into caller's buffer for zero-copy path - // We need to be careful: read into unfilled portion, then decrypt - let before_len = buf.filled().len(); - + + // Read directly into caller buffer, decrypt in-place for the bytes read. + let before = buf.filled().len(); + match Pin::new(&mut this.upstream).poll_read(cx, buf) { Poll::Pending => return Poll::Pending, - + Poll::Ready(Err(e)) => { this.poison(io::Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + Poll::Ready(Ok(())) => { - let after_len = buf.filled().len(); - let bytes_read = after_len - before_len; - + let after = buf.filled().len(); + let bytes_read = after - before; + if bytes_read == 0 { // EOF return Poll::Ready(Ok(())); } - - // Decrypt the newly read data in-place + let filled = buf.filled_mut(); - this.decryptor.apply(&mut filled[before_len..after_len]); - + this.decryptor.apply(&mut filled[before..after]); + + trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk"); + return Poll::Ready(Ok(())); } } @@ -227,115 +241,197 @@ impl AsyncRead for CryptoReader { } impl CryptoReader { - /// Read and decrypt exactly n bytes - /// - /// This is a convenience method that accumulates data until - /// exactly n bytes are 
available. + /// Read and decrypt exactly n bytes. pub async fn read_exact_decrypt(&mut self, n: usize) -> Result { use tokio::io::AsyncReadExt; - + if self.is_poisoned() { return Err(self.take_poison_error()); } - + let mut result = BytesMut::with_capacity(n); - - // First drain any buffered data from Yielding state + + // Drain Yielding buffer if present (rare, kept for completeness) if let CryptoReaderState::Yielding { buffer } = &mut self.state { let to_take = buffer.remaining().min(n); let mut temp = vec![0u8; to_take]; buffer.copy_to(&mut temp); result.extend_from_slice(&temp); - + if buffer.is_empty() { self.state = CryptoReaderState::Idle; } } - - // Read remaining from upstream + while result.len() < n { let mut temp = vec![0u8; n - result.len()]; let read = self.read(&mut temp).await?; - + if read == 0 { return Err(io::Error::new( ErrorKind::UnexpectedEof, - format!("expected {} bytes, got {}", n, result.len()) + format!("expected {} bytes, got {}", n, result.len()), )); } - + result.extend_from_slice(&temp[..read]); } - + Ok(result.freeze()) } - - /// Read into internal buffer and return decrypted bytes - /// - /// Useful when you need the data as Bytes rather than copying to a slice. + + /// Read up to max_size bytes, returning decrypted bytes as Bytes. 
pub async fn read_decrypt(&mut self, max_size: usize) -> Result { use tokio::io::AsyncReadExt; - + if self.is_poisoned() { return Err(self.take_poison_error()); } - - // First check if we have buffered data + if let CryptoReaderState::Yielding { buffer } = &mut self.state { let to_take = buffer.remaining().min(max_size); let mut temp = vec![0u8; to_take]; buffer.copy_to(&mut temp); - + if buffer.is_empty() { self.state = CryptoReaderState::Idle; } - + return Ok(Bytes::from(temp)); } - - // Read from upstream + let mut temp = vec![0u8; max_size]; let read = self.read(&mut temp).await?; - + if read == 0 { return Ok(Bytes::new()); } - + temp.truncate(read); Ok(Bytes::from(temp)) } } +// ============= Pending Ciphertext ============= + +/// Pending ciphertext buffer with explicit position and strict max size. +/// +/// - append plaintext then encrypt appended range in-place - one-touch copy, no extra Vec +/// - move ciphertext from scratch into pending without copying +/// - explicit compaction behavior for long-lived connections +#[derive(Debug)] +struct PendingCiphertext { + buf: BytesMut, + pos: usize, + max_len: usize, +} + +impl PendingCiphertext { + fn new(max_len: usize) -> Self { + Self { + buf: BytesMut::with_capacity(16 * 1024), + pos: 0, + max_len, + } + } + + fn pending_len(&self) -> usize { + self.buf.len().saturating_sub(self.pos) + } + + fn is_empty(&self) -> bool { + self.pending_len() == 0 + } + + fn pending_slice(&self) -> &[u8] { + &self.buf[self.pos..] + } + + fn remaining_capacity(&self) -> usize { + self.max_len.saturating_sub(self.buf.len()) + } + + fn advance(&mut self, n: usize) { + self.pos = (self.pos + n).min(self.buf.len()); + + if self.pos == self.buf.len() { + self.buf.clear(); + self.pos = 0; + return; + } + + // Compact when a large prefix was consumed. + if self.pos >= 32 * 1024 { + let _ = self.buf.split_to(self.pos); + self.pos = 0; + } + } + + /// Replace the entire pending ciphertext by moving `src` in (swap, no copy). 
+ /// + /// Precondition: src.len() <= max_len. + fn replace_with(&mut self, mut src: BytesMut) { + debug_assert!(src.len() <= self.max_len); + + self.buf.clear(); + self.pos = 0; + + // Swap: keep allocations hot and avoid copying bytes. + std::mem::swap(&mut self.buf, &mut src); + } + + /// Append plaintext and encrypt appended range in-place. + /// + /// This is the high-throughput buffering path: + /// - copy plaintext into pending buffer + /// - encrypt only the newly appended bytes + /// + /// CTR state advances by exactly plaintext.len(). + fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> { + if plaintext.is_empty() { + return Ok(()); + } + + if plaintext.len() > self.remaining_capacity() { + return Err(io::Error::new( + ErrorKind::WouldBlock, + "pending ciphertext buffer is full", + )); + } + + let start = self.buf.len(); + self.buf.reserve(plaintext.len()); + self.buf.extend_from_slice(plaintext); + + encryptor.apply(&mut self.buf[start..]); + + Ok(()) + } +} + // ============= CryptoWriter State ============= -/// State machine states for CryptoWriter #[derive(Debug)] enum CryptoWriterState { - /// Ready to accept new data + /// No pending ciphertext buffered. Idle, - - /// Have pending encrypted data to flush - Flushing { - /// Buffer of encrypted data waiting to be written - pending: WriteBuffer, - }, - + + /// There is pending ciphertext to flush. + Flushing { pending: PendingCiphertext }, + /// Stream encountered an error and cannot be used - Poisoned { - /// The error that caused poisoning - error: Option, - }, + Poisoned { error: Option }, } impl StreamState for CryptoWriterState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. 
}) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -347,154 +443,178 @@ impl StreamState for CryptoWriterState { // ============= CryptoWriter ============= -/// Writer that encrypts data using AES-CTR with proper state machine +/// Writer that encrypts data using AES-CTR with correct async semantics. /// -/// This writer handles partial writes correctly by: -/// - Maintaining internal state for pending data -/// - Returning honest byte counts (only what's actually written or safely buffered) -/// - Implementing backpressure when internal buffer is full -/// -/// # State Machine -/// -/// ┌──────────┐ write ┌──────────┐ -/// │ Idle │ ----------> │ Flushing │ -/// │ │ <---------- │ │ -/// └──────────┘ flushed └──────────┘ -/// │ │ -/// │ errors │ -/// ┌───────────────────────────────────┐ -/// │ Poisoned │ -/// └───────────────────────────────────┘ -/// -/// # Backpressure -/// -/// When the internal pending buffer exceeds `MAX_PENDING_WRITE`, the writer -/// will return `Poll::Pending` until some data has been flushed to upstream. +/// - CTR state advances exactly by the number of bytes we report as written +/// - If upstream blocks, ciphertext is buffered/bounded +/// - Backpressure is applied when buffer is full pub struct CryptoWriter { - /// Upstream writer upstream: W, - /// AES-CTR encryptor encryptor: AesCtr, - /// Current state state: CryptoWriterState, + + /// Scratch ciphertext for fast "write-through" path. 
+ /// + /// Flow: + /// - encrypt plaintext into scratch + /// - try upstream write + /// - if Pending/partial: move remainder into pending without re-encrypting + scratch: BytesMut, } impl CryptoWriter { - /// Create new crypto writer pub fn new(upstream: W, encryptor: AesCtr) -> Self { Self { upstream, encryptor, state: CryptoWriterState::Idle, + scratch: BytesMut::with_capacity(16 * 1024), } } - - /// Get reference to upstream + pub fn get_ref(&self) -> &W { &self.upstream } - - /// Get mutable reference to upstream + pub fn get_mut(&mut self) -> &mut W { &mut self.upstream } - - /// Consume and return upstream + pub fn into_inner(self) -> W { self.upstream } - - /// Check if stream is in poisoned state + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } - - /// Get current state name (for debugging) + pub fn state_name(&self) -> &'static str { self.state.state_name() } - - /// Check if there's pending data to flush + pub fn has_pending(&self) -> bool { - matches!(&self.state, CryptoWriterState::Flushing { pending } if !pending.is_empty()) + matches!(self.state, CryptoWriterState::Flushing { .. 
}) } - - /// Get pending bytes count + pub fn pending_len(&self) -> usize { match &self.state { - CryptoWriterState::Flushing { pending } => pending.len(), + CryptoWriterState::Flushing { pending } => pending.pending_len(), _ => 0, } } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = CryptoWriterState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - CryptoWriterState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } + + /// Ensure we are in Flushing state and return mutable pending buffer. + fn ensure_pending<'a>(state: &'a mut CryptoWriterState) -> &'a mut PendingCiphertext { + if matches!(state, CryptoWriterState::Idle) { + *state = CryptoWriterState::Flushing { + pending: PendingCiphertext::new(MAX_PENDING_WRITE), + }; + } + + match state { + CryptoWriterState::Flushing { pending } => pending, + _ => unreachable!("ensure_pending guarantees Flushing state"), + } + } + + /// Select how many plaintext bytes can be accepted in buffering path + /// + /// Requirement: worst case - upstream pending, must buffer all ciphertext + /// for the accepted bytes + fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize { + if buf_len == 0 { + return 0; + } + + match state { + CryptoWriterState::Flushing { pending } => buf_len.min(pending.remaining_capacity()), + CryptoWriterState::Idle => buf_len.min(MAX_PENDING_WRITE), + CryptoWriterState::Poisoned { .. } => 0, + } + } + + /// Encrypt plaintext into scratch (CTR advances by plaintext.len()). 
+ fn encrypt_into_scratch(encryptor: &mut AesCtr, scratch: &mut BytesMut, plaintext: &[u8]) { + scratch.clear(); + scratch.reserve(plaintext.len()); + scratch.extend_from_slice(plaintext); + encryptor.apply(&mut scratch[..]); + } } impl CryptoWriter { - /// Try to flush pending data to upstream + /// Flush as much pending ciphertext as possible /// - /// Returns: - /// - `Poll::Ready(Ok(true))` if all pending data was flushed - /// - `Poll::Ready(Ok(false))` if some data remains - /// - `Poll::Pending` if upstream would block - /// - `Poll::Ready(Err(_))` on error - fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { + /// Returns + /// - Ready(Ok(())) if all pending is flushed or was none + /// - Pending if upstream would block + /// - Ready(Err(_)) on error + fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - CryptoWriterState::Idle => { - return Poll::Ready(Ok(true)); - } - CryptoWriterState::Poisoned { .. } => { let err = self.take_poison_error(); return Poll::Ready(Err(err)); } - + + CryptoWriterState::Idle => return Poll::Ready(Ok(())), + CryptoWriterState::Flushing { pending } => { if pending.is_empty() { self.state = CryptoWriterState::Idle; - return Poll::Ready(Ok(true)); + return Poll::Ready(Ok(())); } - - let data = pending.pending(); + + let data = pending.pending_slice(); + match Pin::new(&mut self.upstream).poll_write(cx, data) { - Poll::Pending => return Poll::Pending, - + Poll::Pending => { + trace!( + pending_len = pending.pending_len(), + pending_cap = pending.remaining_capacity(), + "CryptoWriter: upstream Pending while flushing pending ciphertext" + ); + return Poll::Pending; + } + Poll::Ready(Err(e)) => { self.poison(io::Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + Poll::Ready(Ok(0)) => { let err = io::Error::new( ErrorKind::WriteZero, - "upstream returned 0 bytes written" + "upstream returned 0 bytes written", ); - self.poison(err.into()); - return 
Poll::Ready(Err(io::Error::new( - ErrorKind::WriteZero, - "upstream returned 0 bytes written" - ))); + self.poison(io::Error::new(err.kind(), err.to_string())); + return Poll::Ready(Err(err)); } - + Poll::Ready(Ok(n)) => { pending.advance(n); - // Continue loop to check if fully flushed + + trace!( + flushed = n, + pending_left = pending.pending_len(), + "CryptoWriter: flushed pending ciphertext" + ); + + // continue loop to flush more + continue; } } } @@ -510,153 +630,176 @@ impl AsyncWrite for CryptoWriter { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - - // Check for poisoned state - if let CryptoWriterState::Poisoned { .. } = &this.state { + + // Poisoned? + if matches!(this.state, CryptoWriterState::Poisoned { .. }) { let err = this.take_poison_error(); return Poll::Ready(Err(err)); } - - // Empty write is always successful + + // Empty write is always OK if buf.is_empty() { return Poll::Ready(Ok(0)); } - - // First, try to flush any pending data - match this.poll_flush_pending(cx) { - Poll::Pending => { - // Check backpressure - if this.pending_len() >= MAX_PENDING_WRITE { - // Too much pending, must wait - return Poll::Pending; + + // 1) If we have pending ciphertext, prioritize flushing it + // If upstream pending + // -> still accept some plaintext ONLY if we can buffer + // all ciphertext for the accepted portion - bounded + if matches!(this.state, CryptoWriterState::Flushing { .. }) { + match this.poll_flush_pending(cx) { + Poll::Ready(Ok(())) => { + // pending drained -> proceed + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // Upstream blocked. 
Apply ideal backpressure + // - accept up to remaining pending capacity + // - if no capacity -> pending + let to_accept = + Self::select_to_accept_for_buffering(&this.state, buf.len()); + + if to_accept == 0 { + trace!( + buf_len = buf.len(), + pending_len = this.pending_len(), + "CryptoWriter backpressure: pending full and upstream Pending -> Pending" + ); + return Poll::Pending; + } + + let plaintext = &buf[..to_accept]; + + // Disjoint borrows: borrow encryptor and state separately via a match + let encryptor = &mut this.encryptor; + let pending = Self::ensure_pending(&mut this.state); + + // Should not WouldBlock because to_accept <= remaining_capacity + if let Err(e) = pending.push_encrypted(encryptor, plaintext) { + if e.kind() == ErrorKind::WouldBlock { + return Poll::Pending; + } + return Poll::Ready(Err(e)); + } + + trace!( + accepted = to_accept, + pending_len = pending.pending_len(), + pending_cap = pending.remaining_capacity(), + "CryptoWriter: upstream Pending, buffered ciphertext (accepted plaintext)" + ); + + return Poll::Ready(Ok(to_accept)); } - // Can buffer more, continue below - } - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Ready(Ok(_)) => { - // Flushed (possibly partially), continue } } - - // Encrypt the data - let mut encrypted = buf.to_vec(); - this.encryptor.apply(&mut encrypted); - - // Try to write directly to upstream first - match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) { - Poll::Ready(Ok(n)) if n == encrypted.len() => { - // All data written directly - Poll::Ready(Ok(buf.len())) + + // 2) Fast path: pending empty -> write-through + debug_assert!(matches!(this.state, CryptoWriterState::Idle)); + + // Worst-case buffering requirement + // - If upstream becomes pending -> buffer full ciphertext for accepted bytes + // -> accept at most MAX_PENDING_WRITE per poll_write call + let to_accept = buf.len().min(MAX_PENDING_WRITE); + let plaintext = &buf[..to_accept]; + + Self::encrypt_into_scratch(&mut 
this.encryptor, &mut this.scratch, plaintext); + + match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) { + Poll::Pending => { + // Upstream blocked: buffer FULL ciphertext for accepted bytes. + // Move scratch into pending without copying. + let ciphertext = std::mem::take(&mut this.scratch); + + let pending = Self::ensure_pending(&mut this.state); + pending.replace_with(ciphertext); + + trace!( + accepted = to_accept, + pending_len = pending.pending_len(), + "CryptoWriter: write-through got Pending, buffered full ciphertext" + ); + + Poll::Ready(Ok(to_accept)) } - - Poll::Ready(Ok(n)) => { - // Partial write - buffer the rest - let remaining = &encrypted[n..]; - - // Ensure we're in Flushing state - let pending = match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - CryptoWriterState::Idle => { - this.state = CryptoWriterState::Flushing { - pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), - }; - match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - _ => unreachable!(), - } - } - CryptoWriterState::Poisoned { .. 
} => unreachable!(), - }; - - // Try to buffer remaining - if pending.remaining_capacity() >= remaining.len() { - pending.extend(remaining).expect("capacity checked"); - Poll::Ready(Ok(buf.len())) - } else { - // Not enough buffer space - report what we could write - // The caller will need to retry with the rest - let bytes_accepted = n + pending.remaining_capacity(); - if bytes_accepted > n { - let can_buffer = &encrypted[n..bytes_accepted]; - pending.extend(can_buffer).expect("capacity checked"); - } - Poll::Ready(Ok(bytes_accepted.min(buf.len()))) - } - } - + Poll::Ready(Err(e)) => { this.poison(io::Error::new(e.kind(), e.to_string())); Poll::Ready(Err(e)) } - - Poll::Pending => { - // Upstream would block - buffer the encrypted data - let pending = match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - CryptoWriterState::Idle => { - this.state = CryptoWriterState::Flushing { - pending: WriteBuffer::with_max_size(MAX_PENDING_WRITE), - }; - match &mut this.state { - CryptoWriterState::Flushing { pending } => pending, - _ => unreachable!(), - } - } - CryptoWriterState::Poisoned { .. 
} => unreachable!(), - }; - - // Check if we can buffer all - if pending.remaining_capacity() >= encrypted.len() { - pending.extend(&encrypted).expect("capacity checked"); - // Wake up to try flushing later - cx.waker().wake_by_ref(); - Poll::Ready(Ok(buf.len())) - } else if pending.remaining_capacity() > 0 { - // Partial buffer - let can_buffer = pending.remaining_capacity(); - pending.extend(&encrypted[..can_buffer]).expect("capacity checked"); - cx.waker().wake_by_ref(); - Poll::Ready(Ok(can_buffer)) - } else { - // No buffer space - backpressure - Poll::Pending + + Poll::Ready(Ok(0)) => { + let err = io::Error::new(ErrorKind::WriteZero, "upstream returned 0 bytes written"); + this.poison(io::Error::new(err.kind(), err.to_string())); + Poll::Ready(Err(err)) + } + + Poll::Ready(Ok(n)) => { + if n == this.scratch.len() { + trace!( + accepted = to_accept, + ciphertext_len = this.scratch.len(), + "CryptoWriter: write-through wrote full ciphertext directly" + ); + this.scratch.clear(); + return Poll::Ready(Ok(to_accept)); } + + // Partial upstream write of ciphertext: + // We accepted `to_accept` plaintext bytes, CTR already advanced for to_accept + // Must buffer the remainder ciphertext + warn!( + accepted = to_accept, + ciphertext_len = this.scratch.len(), + written_ciphertext = n, + "CryptoWriter: partial upstream write, buffering remainder" + ); + + // Split off remainder without copying + let remainder = this.scratch.split_off(n); + this.scratch.clear(); + + let pending = Self::ensure_pending(&mut this.state); + pending.replace_with(remainder); + + Poll::Ready(Ok(to_accept)) } } } - + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // First flush our pending buffer - match this.poll_flush_pending(cx)? { - Poll::Pending => return Poll::Pending, - Poll::Ready(false) => { - cx.waker().wake_by_ref(); - return Poll::Pending; - } - Poll::Ready(true) => {} + + if matches!(this.state, CryptoWriterState::Poisoned { .. 
}) { + let err = this.take_poison_error(); + return Poll::Ready(Err(err)); } - - // Then flush upstream + + match this.poll_flush_pending(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + } + Pin::new(&mut this.upstream).poll_flush(cx) } - + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Try to flush pending data first (best effort) + + // Best-effort flush pending ciphertext before shutdown + // If upstream blocks, proceed to shutdown anyway match this.poll_flush_pending(cx) { Poll::Pending => { - // Continue with shutdown anyway after registering waker + debug!( + pending_len = this.pending_len(), + "CryptoWriter: shutdown with pending ciphertext (upstream Pending)" + ); } - Poll::Ready(Err(_)) => { - // Ignore flush errors during shutdown - } - Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(_)) => {} + Poll::Ready(Ok(())) => {} } - - // Shutdown upstream + Pin::new(&mut this.upstream).poll_shutdown(cx) } } @@ -666,28 +809,24 @@ impl AsyncWrite for CryptoWriter { /// Passthrough stream for fast mode - no encryption/decryption /// /// Used when keys are set up so that client and Telegram use the same -/// encryption, allowing data to pass through without re-encryption. 
+/// encryption, allowing data to pass through without re-encryption pub struct PassthroughStream { inner: S, } impl PassthroughStream { - /// Create new passthrough stream pub fn new(inner: S) -> Self { Self { inner } } - - /// Get reference to inner stream + pub fn get_ref(&self) -> &S { &self.inner } - - /// Get mutable reference to inner stream + pub fn get_mut(&mut self) -> &mut S { &mut self.inner } - - /// Consume and return inner stream + pub fn into_inner(self) -> S { self.inner } @@ -711,443 +850,12 @@ impl AsyncWrite for PassthroughStream { ) -> Poll> { Pin::new(&mut self.inner).poll_write(cx, buf) } - + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_flush(cx) } - + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } -} - -// ============= Tests ============= - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::VecDeque; - use std::pin::Pin; - use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable}; - use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; - - // ============= Test Helpers ============= - - fn noop_waker() -> Waker { - const VTABLE: RawWakerVTable = RawWakerVTable::new( - |_| RawWaker::new(std::ptr::null(), &VTABLE), - |_| {}, - |_| {}, - |_| {}, - ); - unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) } - } - - /// Mock writer that simulates partial writes - struct PartialWriter { - /// Max bytes to accept per write - chunk_size: usize, - /// Collected data - data: Vec, - /// Number of writes performed - write_count: usize, - /// If true, return Pending on first write attempt - first_pending: bool, - /// Track if first call happened - first_call: bool, - } - - impl PartialWriter { - fn new(chunk_size: usize) -> Self { - Self { - chunk_size, - data: Vec::new(), - write_count: 0, - first_pending: false, - first_call: true, - } - } - - fn with_first_pending(mut self) -> Self { 
- self.first_pending = true; - self - } - } - - impl AsyncWrite for PartialWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.first_pending && self.first_call { - self.first_call = false; - cx.waker().wake_by_ref(); - return Poll::Pending; - } - - self.write_count += 1; - let to_write = buf.len().min(self.chunk_size); - self.data.extend_from_slice(&buf[..to_write]); - Poll::Ready(Ok(to_write)) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - } - - /// Mock reader that returns data in chunks - struct ChunkedReader { - data: VecDeque, - chunk_size: usize, - } - - impl ChunkedReader { - fn new(data: &[u8], chunk_size: usize) -> Self { - Self { - data: data.iter().copied().collect(), - chunk_size, - } - } - } - - impl AsyncRead for ChunkedReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.data.is_empty() { - return Poll::Ready(Ok(())); - } - - let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining()); - for _ in 0..to_read { - if let Some(byte) = self.data.pop_front() { - buf.put_slice(&[byte]); - } - } - - Poll::Ready(Ok(())) - } - } - - // ============= CryptoReader Tests ============= - - #[tokio::test] - async fn test_crypto_reader_basic() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Encrypt some data - let original = b"Hello, encrypted world!"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - // Create reader - let reader = ChunkedReader::new(&encrypted, 100); - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - // Read and decrypt - let mut buf = vec![0u8; original.len()]; - crypto_reader.read_exact(&mut buf).await.unwrap(); - - 
assert_eq!(&buf, original); - } - - #[tokio::test] - async fn test_crypto_reader_chunked() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let original = b"This is a longer message that will be read in chunks"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - // Read in very small chunks - let reader = ChunkedReader::new(&encrypted, 5); - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - let mut result = Vec::new(); - let mut buf = [0u8; 7]; // Read in chunks different from write chunks - - loop { - let n = crypto_reader.read(&mut buf).await.unwrap(); - if n == 0 { - break; - } - result.extend_from_slice(&buf[..n]); - } - - assert_eq!(&result, original); - } - - #[tokio::test] - async fn test_crypto_reader_read_exact_decrypt() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let original = b"Exact read test data!"; - let mut encryptor = AesCtr::new(&key, iv); - let encrypted = encryptor.encrypt(original); - - let reader = ChunkedReader::new(&encrypted, 3); // Small chunks - let decryptor = AesCtr::new(&key, iv); - let mut crypto_reader = CryptoReader::new(reader, decryptor); - - let result = crypto_reader.read_exact_decrypt(original.len()).await.unwrap(); - assert_eq!(&result[..], original); - } - - // ============= CryptoWriter Tests ============= - - #[test] - fn test_crypto_writer_basic_sync() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - let mock_writer = PartialWriter::new(100); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = b"Hello, world!"; - - // Write - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(13)))); - - // Verify encryption happened - let encrypted = &crypto_writer.upstream.data; - 
assert_eq!(encrypted.len(), original.len()); - assert_ne!(encrypted.as_slice(), original); // Should be encrypted - - // Decrypt and verify - let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = encrypted.clone(); - decryptor.apply(&mut decrypted); - assert_eq!(&decrypted, original); - } - - #[test] - fn test_crypto_writer_partial_write() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Writer that only accepts 5 bytes at a time - let mock_writer = PartialWriter::new(5); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = b"This is a longer message!"; // 25 bytes - - // First write - should accept all 25 bytes (5 written, 20 buffered) - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(25)))); - - // Should have pending data - assert!(crypto_writer.has_pending()); - - // Flush to drain pending - loop { - match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { - Poll::Ready(Ok(())) => break, - Poll::Ready(Err(e)) => panic!("Flush error: {}", e), - Poll::Pending => continue, - } - } - - // All data should be written now - assert!(!crypto_writer.has_pending()); - assert_eq!(crypto_writer.upstream.data.len(), 25); - - // Verify decryption - let mut decryptor = AesCtr::new(&key, iv); - let mut decrypted = crypto_writer.upstream.data.clone(); - decryptor.apply(&mut decrypted); - assert_eq!(&decrypted, original); - } - - #[test] - fn test_crypto_writer_pending_on_first_write() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Writer that returns Pending on first call - let mock_writer = PartialWriter::new(100).with_first_pending(); - let encryptor = AesCtr::new(&key, iv); - let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor); - - let waker = noop_waker(); - let mut cx = Context::from_waker(&waker); - - let original = 
b"Test data"; - - // First write should buffer and return Ready (not Pending) - // because we have buffer space - let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original); - assert!(matches!(result, Poll::Ready(Ok(9)))); - - // Data should be buffered - assert!(crypto_writer.has_pending()); - - // Second poll_flush should succeed - loop { - match Pin::new(&mut crypto_writer).poll_flush(&mut cx) { - Poll::Ready(Ok(())) => break, - Poll::Ready(Err(e)) => panic!("Flush error: {}", e), - Poll::Pending => continue, - } - } - } - - #[tokio::test] - async fn test_crypto_stream_roundtrip() { - let key = [0u8; 32]; - let iv = 12345u128; - - let (client, server) = duplex(4096); - - let encryptor = AesCtr::new(&key, iv); - let decryptor = AesCtr::new(&key, iv); - - let mut writer = CryptoWriter::new(client, encryptor); - let mut reader = CryptoReader::new(server, decryptor); - - // Write - let original = b"Hello, encrypted world!"; - writer.write_all(original).await.unwrap(); - writer.flush().await.unwrap(); - - // Read - let mut buf = vec![0u8; original.len()]; - reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, original); - } - - #[tokio::test] - async fn test_crypto_stream_large_data() { - let key = [0x55u8; 32]; - let iv = 777u128; - - let (client, server) = duplex(1024); - - let encryptor = AesCtr::new(&key, iv); - let decryptor = AesCtr::new(&key, iv); - - let mut writer = CryptoWriter::new(client, encryptor); - let mut reader = CryptoReader::new(server, decryptor); - - // Large data - let original: Vec = (0..10000).map(|i| (i % 256) as u8).collect(); - - // Write in background - let write_data = original.clone(); - let write_handle = tokio::spawn(async move { - writer.write_all(&write_data).await.unwrap(); - writer.flush().await.unwrap(); - writer.shutdown().await.unwrap(); - }); - - // Read - let mut received = Vec::new(); - let mut buf = vec![0u8; 1024]; - loop { - match reader.read(&mut buf).await { - Ok(0) => break, - Ok(n) => 
received.extend_from_slice(&buf[..n]), - Err(e) => panic!("Read error: {}", e), - } - } - - write_handle.await.unwrap(); - - assert_eq!(received, original); - } - - #[tokio::test] - async fn test_crypto_writer_backpressure() { - let key = [0x42u8; 32]; - let iv = 12345u128; - - // Very small buffer duplex - let (client, _server) = duplex(64); - - let encryptor = AesCtr::new(&key, iv); - let mut writer = CryptoWriter::new(client, encryptor); - - // Try to write a lot of data - let large_data = vec![0u8; MAX_PENDING_WRITE + 1000]; - - // This should eventually block due to backpressure - // (duplex buffer full + our pending buffer full) - let write_result = tokio::time::timeout( - std::time::Duration::from_millis(100), - writer.write_all(&large_data) - ).await; - - // Should timeout because we can't write all data - assert!(write_result.is_err()); - } - - // ============= State Tests ============= - - #[test] - fn test_reader_state_transitions() { - let key = [0u8; 32]; - let iv = 0u128; - - let reader = ChunkedReader::new(&[], 10); - let decryptor = AesCtr::new(&key, iv); - let reader = CryptoReader::new(reader, decryptor); - - assert_eq!(reader.state_name(), "Idle"); - assert!(!reader.is_poisoned()); - } - - #[test] - fn test_writer_state_transitions() { - let key = [0u8; 32]; - let iv = 0u128; - - let writer = PartialWriter::new(10); - let encryptor = AesCtr::new(&key, iv); - let writer = CryptoWriter::new(writer, encryptor); - - assert_eq!(writer.state_name(), "Idle"); - assert!(!writer.is_poisoned()); - assert!(!writer.has_pending()); - } - - // ============= Passthrough Tests ============= - - #[tokio::test] - async fn test_passthrough_stream() { - let (client, server) = duplex(4096); - - let mut writer = PassthroughStream::new(client); - let mut reader = PassthroughStream::new(server); - - let data = b"No encryption here!"; - writer.write_all(data).await.unwrap(); - writer.flush().await.unwrap(); - - let mut buf = vec![0u8; data.len()]; - 
reader.read_exact(&mut buf).await.unwrap(); - - assert_eq!(&buf, data); - } } \ No newline at end of file diff --git a/src/stream/tls_stream.rs b/src/stream/tls_stream.rs index 287fdb8..a4edf58 100644 --- a/src/stream/tls_stream.rs +++ b/src/stream/tls_stream.rs @@ -1,17 +1,36 @@ //! Fake TLS 1.3 stream wrappers //! -//! This module provides stateful async stream wrappers that handle -//! TLS record framing with proper partial read/write handling. +//! This module provides stateful async stream wrappers that handle TLS record +//! framing with proper partial read/write handling. //! -//! These are "fake" TLS streams - they wrap data in valid TLS 1.3 -//! Application Data records but don't perform actual TLS encryption. -//! The actual encryption is handled by the crypto layer underneath. +//! These are "fake" TLS streams: +//! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data). +//! - We DO NOT perform real TLS handshake/encryption. +//! - Real crypto for MTProto is handled by the crypto layer underneath. +//! +//! Why do we need this? +//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for +//! domain fronting / traffic camouflage. iOS Telegram clients are known to +//! produce slightly different TLS record sizing patterns than Android/Desktop, +//! including records that exceed 16384 payload bytes by a small overhead. //! //! Key design principles: //! - Explicit state machines for all async operations //! - Never lose data on partial reads //! - Atomic TLS record formation for writes //! - Proper handling of all TLS record types +//! +//! Important nuance (Telegram FakeTLS): +//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes. +//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3 +//! uses AEAD and can include tag/overhead/padding. +//! - Telegram FakeTLS clients (notably iOS) may send Application Data records +//! with length up to 16384 + 24 bytes. 
We accept that as MAX_TLS_CHUNK_SIZE. +//! +//! If you reject those (e.g. validate length <= 16384), you will see errors like: +//! "TLS record too large: 16408 bytes" +//! and uploads from iOS will break (media/file sending), while small traffic +//! may still work. use bytes::{Bytes, BytesMut, BufMut}; use std::io::{self, Error, ErrorKind, Result}; @@ -20,25 +39,29 @@ use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; use crate::protocol::constants::{ - TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, - TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE, + TLS_VERSION, + TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, + TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, + MAX_TLS_CHUNK_SIZE, }; use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; // ============= Constants ============= -/// TLS record header size +/// TLS record header size (type + version + length) const TLS_HEADER_SIZE: usize = 5; -/// Maximum TLS record payload size (16KB as per TLS spec) +/// Maximum TLS fragment size per spec (plaintext fragment). +/// We use this for *outgoing* chunking, because we build plain ApplicationData records. const MAX_TLS_PAYLOAD: usize = 16384; -/// Maximum pending write buffer +/// Maximum pending write buffer for one record remainder. +/// Note: we never queue unlimited amount of data here; state holds at most one record. const MAX_PENDING_WRITE: usize = 64 * 1024; // ============= TLS Record Types ============= -/// Parsed TLS record header +/// Parsed TLS record header (5 bytes) #[derive(Debug, Clone, Copy)] struct TlsRecordHeader { /// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.) @@ -50,50 +73,64 @@ struct TlsRecordHeader { } impl TlsRecordHeader { - /// Parse header from 5 bytes + /// Parse header from exactly 5 bytes. 
+ /// + /// This currently never returns None, but is kept as Option to allow future + /// stricter parsing rules without changing callers. fn parse(header: &[u8; 5]) -> Option { let record_type = header[0]; let version = [header[1], header[2]]; let length = u16::from_be_bytes([header[3], header[4]]); - - Some(Self { - record_type, - version, - length, - }) + Some(Self { record_type, version, length }) } - - /// Validate the header + + /// Validate the header. + /// + /// Nuances: + /// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01), + /// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03). + /// - For Application Data, Telegram FakeTLS may send payload length up to + /// MAX_TLS_CHUNK_SIZE (16384 + 24). + /// - For other record types we keep stricter bounds to avoid memory abuse. fn validate(&self) -> Result<()> { - // Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others) + // Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303). if self.version != [0x03, 0x01] && self.version != TLS_VERSION { return Err(Error::new( ErrorKind::InvalidData, format!("Invalid TLS version: {:02x?}", self.version), )); } - - // Check length - if self.length as usize > MAX_TLS_RECORD_SIZE { - return Err(Error::new( - ErrorKind::InvalidData, - format!("TLS record too large: {} bytes", self.length), - )); + + let len = self.length as usize; + + // Length checks depend on record type. + // Telegram FakeTLS: ApplicationData length may be 16384 + 24. + match self.record_type { + TLS_RECORD_APPLICATION => { + if len > MAX_TLS_CHUNK_SIZE { + return Err(Error::new( + ErrorKind::InvalidData, + format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE), + )); + } + } + + // ChangeCipherSpec/Alert/Handshake should never be that large for our usage + // (post-handshake we don't expect Handshake at all). + // Keep strict to reduce attack surface. 
+ _ => { + if len > MAX_TLS_PAYLOAD { + return Err(Error::new( + ErrorKind::InvalidData, + format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD), + )); + } + } } - + Ok(()) } - - /// Check if this is an application data record - fn is_application_data(&self) -> bool { - self.record_type == TLS_RECORD_APPLICATION - } - - /// Check if this is a change cipher spec record (should be skipped) - fn is_change_cipher_spec(&self) -> bool { - self.record_type == TLS_RECORD_CHANGE_CIPHER - } - + /// Build header bytes fn to_bytes(&self) -> [u8; 5] { [ @@ -113,32 +150,27 @@ impl TlsRecordHeader { enum TlsReaderState { /// Ready to read a new TLS record Idle, - + /// Reading the 5-byte TLS record header ReadingHeader { /// Header buffer (5 bytes) header: HeaderBuffer, }, - - /// Reading the TLS record body + + /// Reading the TLS record body (payload) ReadingBody { - /// Parsed record type record_type: u8, - /// Total body length length: usize, - /// Buffer for body data buffer: BytesMut, }, - - /// Have decrypted data ready to yield to caller + + /// Have buffered data ready to yield to caller Yielding { - /// Buffer containing data to yield buffer: YieldBuffer, }, - + /// Stream encountered an error and cannot be used Poisoned { - /// The error that caused poisoning error: Option, }, } @@ -147,11 +179,11 @@ impl StreamState for TlsReaderState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -165,12 +197,13 @@ impl StreamState for TlsReaderState { // ============= FakeTlsReader ============= -/// Reader that unwraps TLS 1.3 records with proper state machine +/// Reader that unwraps TLS records (FakeTLS). /// -/// This reader handles partial reads correctly by maintaining internal state -/// and never losing any data that has been read from upstream. 
+/// This wrapper is responsible ONLY for TLS record framing and skipping +/// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed +/// as-is to upper layers (crypto stream). /// -/// # State Machine +/// State machine overview: /// /// ┌──────────┐ ┌───────────────┐ /// │ Idle │ -----------------> │ ReadingHeader │ @@ -178,103 +211,69 @@ impl StreamState for TlsReaderState { /// ▲ │ /// │ header complete /// │ │ -/// │ │ +/// │ ▼ /// │ ┌───────────────┐ /// │ skip record │ ReadingBody │ /// │ <-------- (CCS) -------- │ │ /// │ └───────┬───────┘ /// │ │ /// │ body complete -/// │ drained │ -/// │ <-----------------┐ │ -/// │ │ ┌───────────────┐ -/// │ └----- │ Yielding │ +/// │ ▼ +/// │ ┌───────────────┐ +/// │ │ Yielding │ /// │ └───────────────┘ /// │ -/// │ errors /w any state -/// │ +/// │ errors / w any state +/// ▼ /// ┌───────────────────────────────────────────────┐ /// │ Poisoned │ /// └───────────────────────────────────────────────┘ /// +/// NOTE: We must correctly handle partial reads from upstream: +/// - do not assume header arrives in one poll +/// - do not assume body arrives in one poll +/// - never lose already-read bytes pub struct FakeTlsReader { - /// Upstream reader upstream: R, - /// Current state state: TlsReaderState, } impl FakeTlsReader { - /// Create new fake TLS reader pub fn new(upstream: R) -> Self { - Self { - upstream, - state: TlsReaderState::Idle, - } + Self { upstream, state: TlsReaderState::Idle } } - - /// Get reference to upstream - pub fn get_ref(&self) -> &R { - &self.upstream - } - - /// Get mutable reference to upstream - pub fn get_mut(&mut self) -> &mut R { - &mut self.upstream - } - - /// Consume and return upstream - pub fn into_inner(self) -> R { - self.upstream - } - - /// Check if stream is in poisoned state - pub fn is_poisoned(&self) -> bool { - self.state.is_poisoned() - } - - /// Get current state name (for debugging) - pub fn state_name(&self) -> &'static str { - 
self.state.state_name() - } - - /// Transition to poisoned state + + pub fn get_ref(&self) -> &R { &self.upstream } + pub fn get_mut(&mut self) -> &mut R { &mut self.upstream } + pub fn into_inner(self) -> R { self.upstream } + + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } + pub fn state_name(&self) -> &'static str { self.state.state_name() } + fn poison(&mut self, error: io::Error) { self.state = TlsReaderState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - TlsReaderState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } } -/// Result of polling for header completion enum HeaderPollResult { - /// Need more data Pending, - /// EOF at record boundary (clean close) Eof, - /// Header complete, parsed successfully Complete(TlsRecordHeader), - /// Error occurred Error(io::Error), } -/// Result of polling for body completion enum BodyPollResult { - /// Need more data Pending, - /// Body complete Complete(Bytes), - /// Error occurred Error(io::Error), } @@ -285,13 +284,13 @@ impl AsyncRead for FakeTlsReader { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - + loop { // Take ownership of state to avoid borrow conflicts let state = std::mem::replace(&mut this.state, TlsReaderState::Idle); - + match state { - // Poisoned state - return error + // Poisoned state: always return the stored error TlsReaderState::Poisoned { error } => { this.state = TlsReaderState::Poisoned { error: None }; let err = error.unwrap_or_else(|| { @@ -299,55 +298,52 @@ impl AsyncRead for FakeTlsReader { }); return Poll::Ready(Err(err)); } - - // Have buffered data to yield + + 
// Yield buffered plaintext to caller TlsReaderState::Yielding { mut buffer } => { if buf.remaining() == 0 { this.state = TlsReaderState::Yielding { buffer }; return Poll::Ready(Ok(())); } - - // Copy as much as possible to output + let to_copy = buffer.remaining().min(buf.remaining()); let dst = buf.initialize_unfilled_to(to_copy); let copied = buffer.copy_to(dst); buf.advance(copied); - - // If buffer is drained, transition to Idle + if buffer.is_empty() { this.state = TlsReaderState::Idle; } else { this.state = TlsReaderState::Yielding { buffer }; } - + return Poll::Ready(Ok(())); } - - // Ready to read a new TLS record + + // Start reading new record TlsReaderState::Idle => { if buf.remaining() == 0 { this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } - - // Start reading header + this.state = TlsReaderState::ReadingHeader { header: HeaderBuffer::new(), }; - // Continue to ReadingHeader + // loop continues and will handle ReadingHeader } - - // Reading TLS record header + + // Read TLS header (5 bytes) TlsReaderState::ReadingHeader { mut header } => { - // Poll to fill header let result = poll_read_header(&mut this.upstream, cx, &mut header); - + match result { HeaderPollResult::Pending => { this.state = TlsReaderState::ReadingHeader { header }; return Poll::Pending; } HeaderPollResult::Eof => { + // Clean EOF at record boundary this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } @@ -356,15 +352,12 @@ impl AsyncRead for FakeTlsReader { return Poll::Ready(Err(e)); } HeaderPollResult::Complete(parsed) => { - // Validate header if let Err(e) = parsed.validate() { this.poison(Error::new(e.kind(), e.to_string())); return Poll::Ready(Err(e)); } - + let length = parsed.length as usize; - - // Transition to reading body this.state = TlsReaderState::ReadingBody { record_type: parsed.record_type, length, @@ -373,11 +366,11 @@ impl AsyncRead for FakeTlsReader { } } } - - // Reading TLS record body + + // Read TLS payload 
TlsReaderState::ReadingBody { record_type, length, mut buffer } => { let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); - + match result { BodyPollResult::Pending => { this.state = TlsReaderState::ReadingBody { record_type, length, buffer }; @@ -388,42 +381,43 @@ impl AsyncRead for FakeTlsReader { return Poll::Ready(Err(e)); } BodyPollResult::Complete(data) => { - // Handle different record types match record_type { TLS_RECORD_CHANGE_CIPHER => { - // Skip Change Cipher Spec, read next record + // CCS is expected in some clients, ignore it. this.state = TlsReaderState::Idle; continue; } + TLS_RECORD_APPLICATION => { - // Application data - yield to caller + // This is what we actually want. if data.is_empty() { this.state = TlsReaderState::Idle; continue; } - + this.state = TlsReaderState::Yielding { buffer: YieldBuffer::new(data), }; - // Continue to yield + // loop continues and will yield immediately } + TLS_RECORD_ALERT => { - // TLS Alert - treat as EOF + // Treat TLS alert as EOF-like termination. this.state = TlsReaderState::Idle; return Poll::Ready(Ok(())); } + TLS_RECORD_HANDSHAKE => { - let err = Error::new( - ErrorKind::InvalidData, - "unexpected TLS handshake record" - ); + // After FakeTLS handshake is done, we do not expect any Handshake records. 
+ let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record"); this.poison(Error::new(err.kind(), err.to_string())); return Poll::Ready(Err(err)); } + _ => { let err = Error::new( ErrorKind::InvalidData, - format!("unknown TLS record type: 0x{:02x}", record_type) + format!("unknown TLS record type: 0x{:02x}", record_type), ); this.poison(Error::new(err.kind(), err.to_string())); return Poll::Ready(Err(err)); @@ -446,7 +440,7 @@ fn poll_read_header( while !header.is_complete() { let unfilled = header.unfilled_mut(); let mut read_buf = ReadBuf::new(unfilled); - + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { Poll::Pending => return HeaderPollResult::Pending, Poll::Ready(Err(e)) => return HeaderPollResult::Error(e), @@ -459,8 +453,10 @@ fn poll_read_header( } else { return HeaderPollResult::Error(Error::new( ErrorKind::UnexpectedEof, - format!("unexpected EOF in TLS header (got {} of 5 bytes)", - header.as_slice().len()) + format!( + "unexpected EOF in TLS header (got {} of 5 bytes)", + header.as_slice().len() + ), )); } } @@ -468,15 +464,11 @@ fn poll_read_header( } } } - - // Parse header + let header_bytes = *header.as_array(); match TlsRecordHeader::parse(&header_bytes) { Some(h) => HeaderPollResult::Complete(h), - None => HeaderPollResult::Error(Error::new( - ErrorKind::InvalidData, - "failed to parse TLS header" - )), + None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")), } } @@ -487,13 +479,15 @@ fn poll_read_body( buffer: &mut BytesMut, target_len: usize, ) -> BodyPollResult { + // NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime + // issues with BytesMut spare capacity and ReadBuf across polls. + // It's safe and correct; optimization is possible if needed. 
while buffer.len() < target_len { let remaining = target_len - buffer.len(); - - // Read into a temporary buffer + let mut temp = vec![0u8; remaining.min(8192)]; let mut read_buf = ReadBuf::new(&mut temp); - + match Pin::new(&mut *upstream).poll_read(cx, &mut read_buf) { Poll::Pending => return BodyPollResult::Pending, Poll::Ready(Err(e)) => return BodyPollResult::Error(e), @@ -502,67 +496,65 @@ fn poll_read_body( if n == 0 { return BodyPollResult::Error(Error::new( ErrorKind::UnexpectedEof, - format!("unexpected EOF in TLS body (got {} of {} bytes)", - buffer.len(), target_len) + format!( + "unexpected EOF in TLS body (got {} of {} bytes)", + buffer.len(), + target_len + ), )); } buffer.extend_from_slice(&temp[..n]); } } } - + BodyPollResult::Complete(buffer.split().freeze()) } impl FakeTlsReader { - /// Read exactly n bytes through TLS layer + /// Read exactly n bytes through TLS layer. /// - /// This is a convenience method that accumulates data across - /// multiple TLS records until exactly n bytes are available. + /// This accumulates data across multiple TLS ApplicationData records. 
pub async fn read_exact(&mut self, n: usize) -> Result { if self.is_poisoned() { return Err(self.take_poison_error()); } - + let mut result = BytesMut::with_capacity(n); - + while result.len() < n { let mut buf = vec![0u8; n - result.len()]; let read = AsyncReadExt::read(self, &mut buf).await?; - + if read == 0 { return Err(Error::new( ErrorKind::UnexpectedEof, - format!("expected {} bytes, got {}", n, result.len()) + format!("expected {} bytes, got {}", n, result.len()), )); } - + result.extend_from_slice(&buf[..read]); } - + Ok(result.freeze()) } } // ============= FakeTlsWriter State ============= -/// State machine states for FakeTlsWriter #[derive(Debug)] enum TlsWriterState { /// Ready to accept new data Idle, - - /// Writing a complete TLS record + + /// Writing a complete TLS record (header + body), possibly partially WritingRecord { - /// Complete record (header + body) to write record: WriteBuffer, - /// Original payload size (for return value calculation) payload_size: usize, }, - + /// Stream encountered an error and cannot be used Poisoned { - /// The error that caused poisoning error: Option, }, } @@ -571,11 +563,11 @@ impl StreamState for TlsWriterState { fn is_terminal(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn is_poisoned(&self) -> bool { matches!(self, Self::Poisoned { .. }) } - + fn state_name(&self) -> &'static str { match self { Self::Idle => "Idle", @@ -587,101 +579,53 @@ impl StreamState for TlsWriterState { // ============= FakeTlsWriter ============= -/// Writer that wraps data in TLS 1.3 records with proper state machine +/// Writer that wraps bytes into TLS 1.3 Application Data records. 
/// -/// This writer handles partial writes correctly by: -/// - Building complete TLS records before writing -/// - Maintaining internal state for partial record writes -/// - Never splitting a record mid-write to upstream -/// -/// # State Machine -/// -/// ┌──────────┐ write ┌─────────────────┐ -/// │ Idle │ -------------> │ WritingRecord │ -/// │ │ <------------- │ │ -/// └──────────┘ complete └─────────────────┘ -/// │ │ -/// │ < errors > │ -/// │ │ -/// ┌─────────────────────────────────────────────┐ -/// │ Poisoned │ -/// └─────────────────────────────────────────────┘ -/// -/// # Record Formation -/// -/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes. -/// Each record has a 5-byte header prepended. +/// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD). +/// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it. +/// If you want to be more camouflage-accurate later, you could add optional padding +/// to produce records sized closer to MAX_TLS_CHUNK_SIZE. 
pub struct FakeTlsWriter { - /// Upstream writer upstream: W, - /// Current state state: TlsWriterState, } impl FakeTlsWriter { - /// Create new fake TLS writer pub fn new(upstream: W) -> Self { - Self { - upstream, - state: TlsWriterState::Idle, - } + Self { upstream, state: TlsWriterState::Idle } } - - /// Get reference to upstream - pub fn get_ref(&self) -> &W { - &self.upstream - } - - /// Get mutable reference to upstream - pub fn get_mut(&mut self) -> &mut W { - &mut self.upstream - } - - /// Consume and return upstream - pub fn into_inner(self) -> W { - self.upstream - } - - /// Check if stream is in poisoned state - pub fn is_poisoned(&self) -> bool { - self.state.is_poisoned() - } - - /// Get current state name (for debugging) - pub fn state_name(&self) -> &'static str { - self.state.state_name() - } - - /// Check if there's a pending record to write + + pub fn get_ref(&self) -> &W { &self.upstream } + pub fn get_mut(&mut self) -> &mut W { &mut self.upstream } + pub fn into_inner(self) -> W { self.upstream } + + pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() } + pub fn state_name(&self) -> &'static str { self.state.state_name() } + pub fn has_pending(&self) -> bool { matches!(&self.state, TlsWriterState::WritingRecord { record, .. 
} if !record.is_empty()) } - - /// Transition to poisoned state + fn poison(&mut self, error: io::Error) { self.state = TlsWriterState::Poisoned { error: Some(error) }; } - - /// Take error from poisoned state + fn take_poison_error(&mut self) -> io::Error { match &mut self.state { - TlsWriterState::Poisoned { error } => { - error.take().unwrap_or_else(|| { - io::Error::new(ErrorKind::Other, "stream previously poisoned") - }) - } + TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| { + io::Error::new(ErrorKind::Other, "stream previously poisoned") + }), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"), } } - - /// Build a TLS Application Data record + fn build_record(data: &[u8]) -> BytesMut { let header = TlsRecordHeader { record_type: TLS_RECORD_APPLICATION, version: TLS_VERSION, length: data.len() as u16, }; - + let mut record = BytesMut::with_capacity(TLS_HEADER_SIZE + data.len()); record.extend_from_slice(&header.to_bytes()); record.extend_from_slice(data); @@ -689,18 +633,13 @@ impl FakeTlsWriter { } } -/// Result of flushing pending record enum FlushResult { - /// All data flushed, returns payload size Complete(usize), - /// Need to wait for upstream Pending, - /// Error occurred Error(io::Error), } impl FakeTlsWriter { - /// Try to flush pending record to upstream (standalone logic) fn poll_flush_record_inner( upstream: &mut W, cx: &mut Context<'_>, @@ -710,22 +649,17 @@ impl FakeTlsWriter { let data = record.pending(); match Pin::new(&mut *upstream).poll_write(cx, data) { Poll::Pending => return FlushResult::Pending, - Poll::Ready(Err(e)) => return FlushResult::Error(e), - Poll::Ready(Ok(0)) => { return FlushResult::Error(Error::new( ErrorKind::WriteZero, - "upstream returned 0 bytes written" + "upstream returned 0 bytes written", )); } - - Poll::Ready(Ok(n)) => { - record.advance(n); - } + Poll::Ready(Ok(n)) => record.advance(n), } } - + FlushResult::Complete(0) } } @@ -737,10 +671,10 @@ impl AsyncWrite for FakeTlsWriter { 
buf: &[u8], ) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + + // Take ownership of state to avoid borrow conflicts. let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - + match state { TlsWriterState::Poisoned { error } => { this.state = TlsWriterState::Poisoned { error: None }; @@ -749,9 +683,9 @@ impl AsyncWrite for FakeTlsWriter { }); return Poll::Ready(Err(err)); } - + TlsWriterState::WritingRecord { mut record, payload_size } => { - // Continue flushing existing record + // Finish writing previous record before accepting new bytes. match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { FlushResult::Pending => { this.state = TlsWriterState::WritingRecord { record, payload_size }; @@ -763,79 +697,76 @@ impl AsyncWrite for FakeTlsWriter { } FlushResult::Complete(_) => { this.state = TlsWriterState::Idle; - // Fall through to handle new write + // continue to accept new buf below } } } - + TlsWriterState::Idle => { this.state = TlsWriterState::Idle; } } - + // Now in Idle state if buf.is_empty() { return Poll::Ready(Ok(0)); } - + // Chunk to maximum TLS payload size let chunk_size = buf.len().min(MAX_TLS_PAYLOAD); let chunk = &buf[..chunk_size]; - - // Build the complete record + + // Build the complete record (header + payload) let record_data = Self::build_record(chunk); - - // Try to write directly first + match Pin::new(&mut this.upstream).poll_write(cx, &record_data) { Poll::Ready(Ok(n)) if n == record_data.len() => { - // Complete record written Poll::Ready(Ok(chunk_size)) } - + Poll::Ready(Ok(n)) => { - // Partial write - buffer the rest + // Partial write of the record: store remainder. 
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); + // record_data length is <= 16389, fits MAX_PENDING_WRITE let _ = write_buffer.extend(&record_data[n..]); - + this.state = TlsWriterState::WritingRecord { record: write_buffer, payload_size: chunk_size, }; - - // We've accepted chunk_size bytes from caller + + // We have accepted chunk_size bytes from caller. Poll::Ready(Ok(chunk_size)) } - + Poll::Ready(Err(e)) => { this.poison(Error::new(e.kind(), e.to_string())); Poll::Ready(Err(e)) } - + Poll::Pending => { - // Buffer the entire record + // Buffer entire record and report success for this chunk. let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let _ = write_buffer.extend(&record_data); - + this.state = TlsWriterState::WritingRecord { record: write_buffer, payload_size: chunk_size, }; - - // Wake to try again + + // Wake to retry flushing soon. cx.waker().wake_by_ref(); - - // We've accepted chunk_size bytes from caller + Poll::Ready(Ok(chunk_size)) } } } - + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); - + match state { TlsWriterState::Poisoned { error } => { this.state = TlsWriterState::Poisoned { error: None }; @@ -844,7 +775,7 @@ impl AsyncWrite for FakeTlsWriter { }); return Poll::Ready(Err(err)); } - + TlsWriterState::WritingRecord { mut record, payload_size } => { match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { FlushResult::Pending => { @@ -860,64 +791,49 @@ impl AsyncWrite for FakeTlsWriter { } } } - + TlsWriterState::Idle => { this.state = TlsWriterState::Idle; } } - - // Flush upstream + Pin::new(&mut this.upstream).poll_flush(cx) } - + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - - // Take ownership of state + let state = std::mem::replace(&mut this.state, 
TlsWriterState::Idle); - + match state { - TlsWriterState::WritingRecord { mut record, payload_size } => { - // Try to flush pending (best effort) - match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { - FlushResult::Pending => { - // Can't complete flush, continue with shutdown anyway - this.state = TlsWriterState::Idle; - } - FlushResult::Error(_) => { - // Ignore errors during shutdown - this.state = TlsWriterState::Idle; - } - FlushResult::Complete(_) => { - this.state = TlsWriterState::Idle; - } - } + TlsWriterState::WritingRecord { mut record, payload_size: _ } => { + // Best-effort flush (do not block shutdown forever). + let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record); + this.state = TlsWriterState::Idle; } _ => { this.state = TlsWriterState::Idle; } } - - // Shutdown upstream + Pin::new(&mut this.upstream).poll_shutdown(cx) } } impl FakeTlsWriter { - /// Write all data wrapped in TLS records (async method) + /// Write all data wrapped in TLS records. /// - /// This convenience method handles chunking large data into - /// multiple TLS records automatically. + /// Convenience method that chunks into <= 16384 records. pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { let mut written = 0; while written < data.len() { let chunk_size = (data.len() - written).min(MAX_TLS_PAYLOAD); let chunk = &data[written..written + chunk_size]; - + AsyncWriteExt::write_all(self, chunk).await?; written += chunk_size; } - + self.flush().await } } diff --git a/src/transport/socket.rs b/src/transport/socket.rs index 09e5148..a07c21c 100644 --- a/src/transport/socket.rs +++ b/src/transport/socket.rs @@ -30,20 +30,13 @@ pub fn configure_tcp_socket( socket.set_tcp_keepalive(&keepalive)?; } - // Set buffer sizes - set_buffer_sizes(&socket, 65536, 65536)?; + // CHANGED: Removed manual buffer size setting (was 256KB). 
+ // Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical + // for mobile clients to avoid bufferbloat and stalled connections during uploads. Ok(()) } -/// Set socket buffer sizes -fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> { - // These may fail on some systems, so we ignore errors - let _ = socket.set_recv_buffer_size(recv); - let _ = socket.set_send_buffer_size(send); - Ok(()) -} - /// Configure socket for accepting client connections pub fn configure_client_socket( stream: &TcpStream, @@ -65,6 +58,8 @@ pub fn configure_client_socket( socket.set_tcp_keepalive(&keepalive)?; // Set TCP user timeout (Linux only) + // NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout + // is implemented in relay_bidirectional instead #[cfg(target_os = "linux")] { use std::os::unix::io::AsRawFd;