7 Commits

Author SHA1 Message Date
Alexey
4fa6867056 Merge pull request #7 from telemt/1.0.3.0
1.0.3.0
2026-01-12 00:49:31 +03:00
Alexey
54ea6efdd0 Global rewrite of AES-CTR + Upstream Pending + to_accept selection 2026-01-12 00:46:51 +03:00
brekotis
27ac32a901 Fixes in TLS for iOS 2026-01-12 00:32:42 +03:00
Alexey
829f53c123 Fixes for iOS 2026-01-11 22:59:51 +03:00
Alexey
43eae6127d Update README.md 2026-01-10 22:17:03 +03:00
Alexey
a03212c8cc Update README.md 2026-01-10 22:15:02 +03:00
Alexey
2613969a7c Update rust.yml 2026-01-09 23:15:52 +03:00
11 changed files with 836 additions and 1117 deletions

View File

@@ -10,8 +10,8 @@ env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
jobs: jobs:
build-and-test: build:
name: Build & Test name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View File

@@ -231,6 +231,10 @@ telemt config.toml
- Memory safety and reduced attack surface - Memory safety and reduced attack surface
- Tokio's asynchronous architecture - Tokio's asynchronous architecture
## Issues
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
- ⌛ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
## Roadmap ## Roadmap
- Public IP in links - Public IP in links
- Config Reload-on-fly - Config Reload-on-fly

View File

@@ -163,9 +163,12 @@ fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() } fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 } fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 } fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 10 } // CHANGED: Increased handshake timeout for bad mobile networks
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 } fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 600 } // CHANGED: Reduced keepalive from 600s to 60s.
// Mobile NATs often drop idle connections after 60-120s.
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 } fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() } fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 } fn default_fake_cert_len() -> usize { 2048 }

View File

@@ -20,7 +20,7 @@ mod util;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::stats::Stats; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip; use crate::util::ip::detect_ip;
@@ -55,6 +55,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
// CHANGED: Initialize global ReplayChecker here instead of per-connection
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
// Initialize Upstream Manager // Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
@@ -145,13 +148,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
// Accept loop // Accept loop
// For simplicity in this slice, we just spawn a task for each listener
// In a real high-perf scenario, we might want a more complex accept loop
for listener in listeners { for listener in listeners {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -160,6 +161,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
@@ -167,7 +169,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
peer_addr, peer_addr,
config, config,
stats, stats,
upstream_manager upstream_manager,
replay_checker // Pass global checker
).run().await { ).run().await {
// Log only relevant errors // Log only relevant errors
// debug!("Connection error: {}", e); // debug!("Connection error: {}", e);

View File

@@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes ============= // ============= Buffer Sizes =============
/// Default buffer size /// Default buffer size
pub const DEFAULT_BUFFER_SIZE: usize = 65536; /// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
/// the new buffering strategy for better iOS upload performance.
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
/// Small buffer size for bad client handling /// Small buffer size for bad client handling
pub const SMALL_BUFFER_SIZE: usize = 8192; pub const SMALL_BUFFER_SIZE: usize = 8192;

View File

@@ -45,11 +45,10 @@ impl ClientHandler {
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker
) -> RunningClientHandler { ) -> RunningClientHandler {
// Note: ReplayChecker should be shared globally for proper replay protection // CHANGED: Removed local creation of ReplayChecker.
// Creating it per-connection disables replay protection across connections // It is now passed from main.rs to ensure global replay protection.
// TODO: Pass Arc<ReplayChecker> from main.rs
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
RunningClientHandler { RunningClientHandler {
stream, stream,

View File

@@ -1,13 +1,21 @@
//! Bidirectional Relay //! Bidirectional Relay
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace, warn}; use tokio::time::Instant;
use tracing::{debug, trace, warn, info};
use crate::error::Result; use crate::error::Result;
use crate::stats::Stats; use crate::stats::Stats;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
const BUFFER_SIZE: usize = 65536; // CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
const BUFFER_SIZE: usize = 16384;
// Activity timeout for iOS compatibility (30 minutes)
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
/// Relay data bidirectionally between client and server /// Relay data bidirectionally between client and server
pub async fn relay_bidirectional<CR, CW, SR, SW>( pub async fn relay_bidirectional<CR, CW, SR, SW>(
@@ -36,15 +44,40 @@ where
let c2s_bytes_clone = Arc::clone(&c2s_bytes); let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes); let s2c_bytes_clone = Arc::clone(&s2c_bytes);
// Client -> Server task // Activity timeout for iOS compatibility
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
// Client -> Server task with activity timeout
let c2s = tokio::spawn(async move { let c2s = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; let mut buf = vec![0u8; BUFFER_SIZE];
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { loop {
match client_reader.read(&mut buf).await { // Read with timeout to prevent infinite hang on iOS
Ok(0) => { let read_result = tokio::time::timeout(
activity_timeout,
client_reader.read(&mut buf)
).await;
match read_result {
// Timeout - no activity for too long
Err(_) => {
warn!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (C->S) - no data received"
);
let _ = server_writer.shutdown().await;
break;
}
// Read successful
Ok(Ok(0)) => {
debug!( debug!(
user = %user_c2s, user = %user_c2s,
total_bytes = total_bytes, total_bytes = total_bytes,
@@ -54,9 +87,11 @@ where
let _ = server_writer.shutdown().await; let _ = server_writer.shutdown().await;
break; break;
} }
Ok(n) => {
Ok(Ok(n)) => {
total_bytes += n as u64; total_bytes += n as u64;
msg_count += 1; msg_count += 1;
last_activity = Instant::now();
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed); c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_c2s.add_user_octets_from(&user_c2s, n as u64); stats_c2s.add_user_octets_from(&user_c2s, n as u64);
@@ -70,6 +105,19 @@ where
"C->S data" "C->S data"
); );
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress"
);
last_log = Instant::now();
}
if let Err(e) = server_writer.write_all(&buf[..n]).await { if let Err(e) = server_writer.write_all(&buf[..n]).await {
debug!(user = %user_c2s, error = %e, "Failed to write to server"); debug!(user = %user_c2s, error = %e, "Failed to write to server");
break; break;
@@ -79,7 +127,8 @@ where
break; break;
} }
} }
Err(e) => {
Ok(Err(e)) => {
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error"); debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
break; break;
} }
@@ -87,15 +136,37 @@ where
} }
}); });
// Server -> Client task // Server -> Client task with activity timeout
let s2c = tokio::spawn(async move { let s2c = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; let mut buf = vec![0u8; BUFFER_SIZE];
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { loop {
match server_reader.read(&mut buf).await { // Read with timeout to prevent infinite hang on iOS
Ok(0) => { let read_result = tokio::time::timeout(
activity_timeout,
server_reader.read(&mut buf)
).await;
match read_result {
// Timeout - no activity for too long
Err(_) => {
warn!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (S->C) - no data received"
);
let _ = client_writer.shutdown().await;
break;
}
// Read successful
Ok(Ok(0)) => {
debug!( debug!(
user = %user_s2c, user = %user_s2c,
total_bytes = total_bytes, total_bytes = total_bytes,
@@ -105,9 +176,11 @@ where
let _ = client_writer.shutdown().await; let _ = client_writer.shutdown().await;
break; break;
} }
Ok(n) => {
Ok(Ok(n)) => {
total_bytes += n as u64; total_bytes += n as u64;
msg_count += 1; msg_count += 1;
last_activity = Instant::now();
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed); s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_s2c.add_user_octets_to(&user_s2c, n as u64); stats_s2c.add_user_octets_to(&user_s2c, n as u64);
@@ -121,6 +194,19 @@ where
"S->C data" "S->C data"
); );
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress"
);
last_log = Instant::now();
}
if let Err(e) = client_writer.write_all(&buf[..n]).await { if let Err(e) = client_writer.write_all(&buf[..n]).await {
debug!(user = %user_s2c, error = %e, "Failed to write to client"); debug!(user = %user_s2c, error = %e, "Failed to write to client");
break; break;
@@ -130,7 +216,8 @@ where
break; break;
} }
} }
Err(e) => {
Ok(Err(e)) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error"); debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break; break;
} }

View File

@@ -11,8 +11,9 @@ use std::sync::Arc;
// ============= Configuration ============= // ============= Configuration =============
/// Default buffer size (64KB - good for MTProto) /// Default buffer size
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024; /// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
/// Default maximum number of pooled buffers /// Default maximum number of pooled buffers
pub const DEFAULT_MAX_BUFFERS: usize = 1024; pub const DEFAULT_MAX_BUFFERS: usize = 1024;

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,36 @@
//! Fake TLS 1.3 stream wrappers //! Fake TLS 1.3 stream wrappers
//! //!
//! This module provides stateful async stream wrappers that handle //! This module provides stateful async stream wrappers that handle TLS record
//! TLS record framing with proper partial read/write handling. //! framing with proper partial read/write handling.
//! //!
//! These are "fake" TLS streams - they wrap data in valid TLS 1.3 //! These are "fake" TLS streams:
//! Application Data records but don't perform actual TLS encryption. //! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data).
//! The actual encryption is handled by the crypto layer underneath. //! - We DO NOT perform real TLS handshake/encryption.
//! - Real crypto for MTProto is handled by the crypto layer underneath.
//!
//! Why do we need this?
//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for
//! domain fronting / traffic camouflage. iOS Telegram clients are known to
//! produce slightly different TLS record sizing patterns than Android/Desktop,
//! including records that exceed 16384 payload bytes by a small overhead.
//! //!
//! Key design principles: //! Key design principles:
//! - Explicit state machines for all async operations //! - Explicit state machines for all async operations
//! - Never lose data on partial reads //! - Never lose data on partial reads
//! - Atomic TLS record formation for writes //! - Atomic TLS record formation for writes
//! - Proper handling of all TLS record types //! - Proper handling of all TLS record types
//!
//! Important nuance (Telegram FakeTLS):
//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes.
//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3
//! uses AEAD and can include tag/overhead/padding.
//! - Telegram FakeTLS clients (notably iOS) may send Application Data records
//! with length up to 16384 + 24 bytes. We accept that as MAX_TLS_CHUNK_SIZE.
//!
//! If you reject those (e.g. validate length <= 16384), you will see errors like:
//! "TLS record too large: 16408 bytes"
//! and uploads from iOS will break (media/file sending), while small traffic
//! may still work.
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind, Result}; use std::io::{self, Error, ErrorKind, Result};
@@ -20,25 +39,29 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::protocol::constants::{ use crate::protocol::constants::{
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_VERSION,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
MAX_TLS_CHUNK_SIZE,
}; };
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
// ============= Constants ============= // ============= Constants =============
/// TLS record header size /// TLS record header size (type + version + length)
const TLS_HEADER_SIZE: usize = 5; const TLS_HEADER_SIZE: usize = 5;
/// Maximum TLS record payload size (16KB as per TLS spec) /// Maximum TLS fragment size per spec (plaintext fragment).
/// We use this for *outgoing* chunking, because we build plain ApplicationData records.
const MAX_TLS_PAYLOAD: usize = 16384; const MAX_TLS_PAYLOAD: usize = 16384;
/// Maximum pending write buffer /// Maximum pending write buffer for one record remainder.
/// Note: we never queue unlimited amount of data here; state holds at most one record.
const MAX_PENDING_WRITE: usize = 64 * 1024; const MAX_PENDING_WRITE: usize = 64 * 1024;
// ============= TLS Record Types ============= // ============= TLS Record Types =============
/// Parsed TLS record header /// Parsed TLS record header (5 bytes)
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
struct TlsRecordHeader { struct TlsRecordHeader {
/// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.) /// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.)
@@ -50,22 +73,27 @@ struct TlsRecordHeader {
} }
impl TlsRecordHeader { impl TlsRecordHeader {
/// Parse header from 5 bytes /// Parse header from exactly 5 bytes.
///
/// This currently never returns None, but is kept as Option to allow future
/// stricter parsing rules without changing callers.
fn parse(header: &[u8; 5]) -> Option<Self> { fn parse(header: &[u8; 5]) -> Option<Self> {
let record_type = header[0]; let record_type = header[0];
let version = [header[1], header[2]]; let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]); let length = u16::from_be_bytes([header[3], header[4]]);
Some(Self { record_type, version, length })
Some(Self {
record_type,
version,
length,
})
} }
/// Validate the header /// Validate the header.
///
/// Nuances:
/// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01),
/// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03).
/// - For Application Data, Telegram FakeTLS may send payload length up to
/// MAX_TLS_CHUNK_SIZE (16384 + 24).
/// - For other record types we keep stricter bounds to avoid memory abuse.
fn validate(&self) -> Result<()> { fn validate(&self) -> Result<()> {
// Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others) // Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303).
if self.version != [0x03, 0x01] && self.version != TLS_VERSION { if self.version != [0x03, 0x01] && self.version != TLS_VERSION {
return Err(Error::new( return Err(Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
@@ -73,27 +101,36 @@ impl TlsRecordHeader {
)); ));
} }
// Check length let len = self.length as usize;
if self.length as usize > MAX_TLS_RECORD_SIZE {
// Length checks depend on record type.
// Telegram FakeTLS: ApplicationData length may be 16384 + 24.
match self.record_type {
TLS_RECORD_APPLICATION => {
if len > MAX_TLS_CHUNK_SIZE {
return Err(Error::new( return Err(Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("TLS record too large: {} bytes", self.length), format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE),
)); ));
} }
}
// ChangeCipherSpec/Alert/Handshake should never be that large for our usage
// (post-handshake we don't expect Handshake at all).
// Keep strict to reduce attack surface.
_ => {
if len > MAX_TLS_PAYLOAD {
return Err(Error::new(
ErrorKind::InvalidData,
format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD),
));
}
}
}
Ok(()) Ok(())
} }
/// Check if this is an application data record
fn is_application_data(&self) -> bool {
self.record_type == TLS_RECORD_APPLICATION
}
/// Check if this is a change cipher spec record (should be skipped)
fn is_change_cipher_spec(&self) -> bool {
self.record_type == TLS_RECORD_CHANGE_CIPHER
}
/// Build header bytes /// Build header bytes
fn to_bytes(&self) -> [u8; 5] { fn to_bytes(&self) -> [u8; 5] {
[ [
@@ -120,25 +157,20 @@ enum TlsReaderState {
header: HeaderBuffer<TLS_HEADER_SIZE>, header: HeaderBuffer<TLS_HEADER_SIZE>,
}, },
/// Reading the TLS record body /// Reading the TLS record body (payload)
ReadingBody { ReadingBody {
/// Parsed record type
record_type: u8, record_type: u8,
/// Total body length
length: usize, length: usize,
/// Buffer for body data
buffer: BytesMut, buffer: BytesMut,
}, },
/// Have decrypted data ready to yield to caller /// Have buffered data ready to yield to caller
Yielding { Yielding {
/// Buffer containing data to yield
buffer: YieldBuffer, buffer: YieldBuffer,
}, },
/// Stream encountered an error and cannot be used /// Stream encountered an error and cannot be used
Poisoned { Poisoned {
/// The error that caused poisoning
error: Option<io::Error>, error: Option<io::Error>,
}, },
} }
@@ -165,12 +197,13 @@ impl StreamState for TlsReaderState {
// ============= FakeTlsReader ============= // ============= FakeTlsReader =============
/// Reader that unwraps TLS 1.3 records with proper state machine /// Reader that unwraps TLS records (FakeTLS).
/// ///
/// This reader handles partial reads correctly by maintaining internal state /// This wrapper is responsible ONLY for TLS record framing and skipping
/// and never losing any data that has been read from upstream. /// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed
/// as-is to upper layers (crypto stream).
/// ///
/// # State Machine /// State machine overview:
/// ///
/// ┌──────────┐ ┌───────────────┐ /// ┌──────────┐ ┌───────────────┐
/// │ Idle │ -----------------> │ ReadingHeader │ /// │ Idle │ -----------------> │ ReadingHeader │
@@ -178,103 +211,69 @@ impl StreamState for TlsReaderState {
/// ▲ │ /// ▲ │
/// │ header complete /// │ header complete
/// │ │ /// │ │
/// │ /// │
/// │ ┌───────────────┐ /// │ ┌───────────────┐
/// │ skip record │ ReadingBody │ /// │ skip record │ ReadingBody │
/// │ <-------- (CCS) -------- │ │ /// │ <-------- (CCS) -------- │ │
/// │ └───────┬───────┘ /// │ └───────┬───────┘
/// │ │ /// │ │
/// │ body complete /// │ body complete
/// │ drained /// │
/// │ <-----------------┐ │ /// │ ┌───────────────┐
/// │ │ ┌───────────────┐ /// │ │ Yielding │
/// │ └----- │ Yielding │
/// │ └───────────────┘ /// │ └───────────────┘
/// │ /// │
/// │ errors / w any state /// │ errors / w any state
/// ///
/// ┌───────────────────────────────────────────────┐ /// ┌───────────────────────────────────────────────┐
/// │ Poisoned │ /// │ Poisoned │
/// └───────────────────────────────────────────────┘ /// └───────────────────────────────────────────────┘
/// ///
/// NOTE: We must correctly handle partial reads from upstream:
/// - do not assume header arrives in one poll
/// - do not assume body arrives in one poll
/// - never lose already-read bytes
pub struct FakeTlsReader<R> { pub struct FakeTlsReader<R> {
/// Upstream reader
upstream: R, upstream: R,
/// Current state
state: TlsReaderState, state: TlsReaderState,
} }
impl<R> FakeTlsReader<R> { impl<R> FakeTlsReader<R> {
/// Create new fake TLS reader
pub fn new(upstream: R) -> Self { pub fn new(upstream: R) -> Self {
Self { Self { upstream, state: TlsReaderState::Idle }
upstream,
state: TlsReaderState::Idle,
}
} }
/// Get reference to upstream pub fn get_ref(&self) -> &R { &self.upstream }
pub fn get_ref(&self) -> &R { pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
&self.upstream pub fn into_inner(self) -> R { self.upstream }
}
/// Get mutable reference to upstream pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn get_mut(&mut self) -> &mut R { pub fn state_name(&self) -> &'static str { self.state.state_name() }
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
/// Check if stream is in poisoned state
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
/// Get current state name (for debugging)
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
/// Transition to poisoned state
fn poison(&mut self, error: io::Error) { fn poison(&mut self, error: io::Error) {
self.state = TlsReaderState::Poisoned { error: Some(error) }; self.state = TlsReaderState::Poisoned { error: Some(error) };
} }
/// Take error from poisoned state
fn take_poison_error(&mut self) -> io::Error { fn take_poison_error(&mut self) -> io::Error {
match &mut self.state { match &mut self.state {
TlsReaderState::Poisoned { error } => { TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
error.take().unwrap_or_else(|| {
io::Error::new(ErrorKind::Other, "stream previously poisoned") io::Error::new(ErrorKind::Other, "stream previously poisoned")
}) }),
}
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
} }
} }
} }
/// Result of polling for header completion
enum HeaderPollResult { enum HeaderPollResult {
/// Need more data
Pending, Pending,
/// EOF at record boundary (clean close)
Eof, Eof,
/// Header complete, parsed successfully
Complete(TlsRecordHeader), Complete(TlsRecordHeader),
/// Error occurred
Error(io::Error), Error(io::Error),
} }
/// Result of polling for body completion
enum BodyPollResult { enum BodyPollResult {
/// Need more data
Pending, Pending,
/// Body complete
Complete(Bytes), Complete(Bytes),
/// Error occurred
Error(io::Error), Error(io::Error),
} }
@@ -291,7 +290,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
let state = std::mem::replace(&mut this.state, TlsReaderState::Idle); let state = std::mem::replace(&mut this.state, TlsReaderState::Idle);
match state { match state {
// Poisoned state - return error // Poisoned state: always return the stored error
TlsReaderState::Poisoned { error } => { TlsReaderState::Poisoned { error } => {
this.state = TlsReaderState::Poisoned { error: None }; this.state = TlsReaderState::Poisoned { error: None };
let err = error.unwrap_or_else(|| { let err = error.unwrap_or_else(|| {
@@ -300,20 +299,18 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
} }
// Have buffered data to yield // Yield buffered plaintext to caller
TlsReaderState::Yielding { mut buffer } => { TlsReaderState::Yielding { mut buffer } => {
if buf.remaining() == 0 { if buf.remaining() == 0 {
this.state = TlsReaderState::Yielding { buffer }; this.state = TlsReaderState::Yielding { buffer };
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Copy as much as possible to output
let to_copy = buffer.remaining().min(buf.remaining()); let to_copy = buffer.remaining().min(buf.remaining());
let dst = buf.initialize_unfilled_to(to_copy); let dst = buf.initialize_unfilled_to(to_copy);
let copied = buffer.copy_to(dst); let copied = buffer.copy_to(dst);
buf.advance(copied); buf.advance(copied);
// If buffer is drained, transition to Idle
if buffer.is_empty() { if buffer.is_empty() {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
} else { } else {
@@ -323,23 +320,21 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Ready to read a new TLS record // Start reading new record
TlsReaderState::Idle => { TlsReaderState::Idle => {
if buf.remaining() == 0 { if buf.remaining() == 0 {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Start reading header
this.state = TlsReaderState::ReadingHeader { this.state = TlsReaderState::ReadingHeader {
header: HeaderBuffer::new(), header: HeaderBuffer::new(),
}; };
// Continue to ReadingHeader // loop continues and will handle ReadingHeader
} }
// Reading TLS record header // Read TLS header (5 bytes)
TlsReaderState::ReadingHeader { mut header } => { TlsReaderState::ReadingHeader { mut header } => {
// Poll to fill header
let result = poll_read_header(&mut this.upstream, cx, &mut header); let result = poll_read_header(&mut this.upstream, cx, &mut header);
match result { match result {
@@ -348,6 +343,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Pending; return Poll::Pending;
} }
HeaderPollResult::Eof => { HeaderPollResult::Eof => {
// Clean EOF at record boundary
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
@@ -356,15 +352,12 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
HeaderPollResult::Complete(parsed) => { HeaderPollResult::Complete(parsed) => {
// Validate header
if let Err(e) = parsed.validate() { if let Err(e) = parsed.validate() {
this.poison(Error::new(e.kind(), e.to_string())); this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
let length = parsed.length as usize; let length = parsed.length as usize;
// Transition to reading body
this.state = TlsReaderState::ReadingBody { this.state = TlsReaderState::ReadingBody {
record_type: parsed.record_type, record_type: parsed.record_type,
length, length,
@@ -374,7 +367,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
} }
} }
// Reading TLS record body // Read TLS payload
TlsReaderState::ReadingBody { record_type, length, mut buffer } => { TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
@@ -388,15 +381,15 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
BodyPollResult::Complete(data) => { BodyPollResult::Complete(data) => {
// Handle different record types
match record_type { match record_type {
TLS_RECORD_CHANGE_CIPHER => { TLS_RECORD_CHANGE_CIPHER => {
// Skip Change Cipher Spec, read next record // CCS is expected in some clients, ignore it.
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
continue; continue;
} }
TLS_RECORD_APPLICATION => { TLS_RECORD_APPLICATION => {
// Application data - yield to caller // This is what we actually want.
if data.is_empty() { if data.is_empty() {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
continue; continue;
@@ -405,25 +398,26 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
this.state = TlsReaderState::Yielding { this.state = TlsReaderState::Yielding {
buffer: YieldBuffer::new(data), buffer: YieldBuffer::new(data),
}; };
// Continue to yield // loop continues and will yield immediately
} }
TLS_RECORD_ALERT => { TLS_RECORD_ALERT => {
// TLS Alert - treat as EOF // Treat TLS alert as EOF-like termination.
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
TLS_RECORD_HANDSHAKE => { TLS_RECORD_HANDSHAKE => {
let err = Error::new( // After FakeTLS handshake is done, we do not expect any Handshake records.
ErrorKind::InvalidData, let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
"unexpected TLS handshake record"
);
this.poison(Error::new(err.kind(), err.to_string())); this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
} }
_ => { _ => {
let err = Error::new( let err = Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("unknown TLS record type: 0x{:02x}", record_type) format!("unknown TLS record type: 0x{:02x}", record_type),
); );
this.poison(Error::new(err.kind(), err.to_string())); this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
@@ -459,8 +453,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
} else { } else {
return HeaderPollResult::Error(Error::new( return HeaderPollResult::Error(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("unexpected EOF in TLS header (got {} of 5 bytes)", format!(
header.as_slice().len()) "unexpected EOF in TLS header (got {} of 5 bytes)",
header.as_slice().len()
),
)); ));
} }
} }
@@ -469,14 +465,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
} }
} }
// Parse header
let header_bytes = *header.as_array(); let header_bytes = *header.as_array();
match TlsRecordHeader::parse(&header_bytes) { match TlsRecordHeader::parse(&header_bytes) {
Some(h) => HeaderPollResult::Complete(h), Some(h) => HeaderPollResult::Complete(h),
None => HeaderPollResult::Error(Error::new( None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
ErrorKind::InvalidData,
"failed to parse TLS header"
)),
} }
} }
@@ -487,10 +479,12 @@ fn poll_read_body<R: AsyncRead + Unpin>(
buffer: &mut BytesMut, buffer: &mut BytesMut,
target_len: usize, target_len: usize,
) -> BodyPollResult { ) -> BodyPollResult {
// NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime
// issues with BytesMut spare capacity and ReadBuf across polls.
// It's safe and correct; optimization is possible if needed.
while buffer.len() < target_len { while buffer.len() < target_len {
let remaining = target_len - buffer.len(); let remaining = target_len - buffer.len();
// Read into a temporary buffer
let mut temp = vec![0u8; remaining.min(8192)]; let mut temp = vec![0u8; remaining.min(8192)];
let mut read_buf = ReadBuf::new(&mut temp); let mut read_buf = ReadBuf::new(&mut temp);
@@ -502,8 +496,11 @@ fn poll_read_body<R: AsyncRead + Unpin>(
if n == 0 { if n == 0 {
return BodyPollResult::Error(Error::new( return BodyPollResult::Error(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("unexpected EOF in TLS body (got {} of {} bytes)", format!(
buffer.len(), target_len) "unexpected EOF in TLS body (got {} of {} bytes)",
buffer.len(),
target_len
),
)); ));
} }
buffer.extend_from_slice(&temp[..n]); buffer.extend_from_slice(&temp[..n]);
@@ -515,10 +512,9 @@ fn poll_read_body<R: AsyncRead + Unpin>(
} }
impl<R: AsyncRead + Unpin> FakeTlsReader<R> { impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
/// Read exactly n bytes through TLS layer /// Read exactly n bytes through TLS layer.
/// ///
/// This is a convenience method that accumulates data across /// This accumulates data across multiple TLS ApplicationData records.
/// multiple TLS records until exactly n bytes are available.
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> { pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
if self.is_poisoned() { if self.is_poisoned() {
return Err(self.take_poison_error()); return Err(self.take_poison_error());
@@ -533,7 +529,7 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
if read == 0 { if read == 0 {
return Err(Error::new( return Err(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("expected {} bytes, got {}", n, result.len()) format!("expected {} bytes, got {}", n, result.len()),
)); ));
} }
@@ -546,23 +542,19 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
// ============= FakeTlsWriter State ============= // ============= FakeTlsWriter State =============
/// State machine states for FakeTlsWriter
#[derive(Debug)] #[derive(Debug)]
enum TlsWriterState { enum TlsWriterState {
/// Ready to accept new data /// Ready to accept new data
Idle, Idle,
/// Writing a complete TLS record /// Writing a complete TLS record (header + body), possibly partially
WritingRecord { WritingRecord {
/// Complete record (header + body) to write
record: WriteBuffer, record: WriteBuffer,
/// Original payload size (for return value calculation)
payload_size: usize, payload_size: usize,
}, },
/// Stream encountered an error and cannot be used /// Stream encountered an error and cannot be used
Poisoned { Poisoned {
/// The error that caused poisoning
error: Option<io::Error>, error: Option<io::Error>,
}, },
} }
@@ -587,94 +579,46 @@ impl StreamState for TlsWriterState {
// ============= FakeTlsWriter ============= // ============= FakeTlsWriter =============
/// Writer that wraps data in TLS 1.3 records with proper state machine /// Writer that wraps bytes into TLS 1.3 Application Data records.
/// ///
/// This writer handles partial writes correctly by: /// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD).
/// - Building complete TLS records before writing /// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it.
/// - Maintaining internal state for partial record writes /// If you want to be more camouflage-accurate later, you could add optional padding
/// - Never splitting a record mid-write to upstream /// to produce records sized closer to MAX_TLS_CHUNK_SIZE.
///
/// # State Machine
///
/// ┌──────────┐ write ┌─────────────────┐
/// │ Idle │ -------------> │ WritingRecord │
/// │ │ <------------- │ │
/// └──────────┘ complete └─────────────────┘
/// │ │
/// │ < errors > │
/// │ │
/// ┌─────────────────────────────────────────────┐
/// │ Poisoned │
/// └─────────────────────────────────────────────┘
///
/// # Record Formation
///
/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes.
/// Each record has a 5-byte header prepended.
pub struct FakeTlsWriter<W> { pub struct FakeTlsWriter<W> {
/// Upstream writer
upstream: W, upstream: W,
/// Current state
state: TlsWriterState, state: TlsWriterState,
} }
impl<W> FakeTlsWriter<W> { impl<W> FakeTlsWriter<W> {
/// Create new fake TLS writer
pub fn new(upstream: W) -> Self { pub fn new(upstream: W) -> Self {
Self { Self { upstream, state: TlsWriterState::Idle }
upstream,
state: TlsWriterState::Idle,
}
} }
/// Get reference to upstream pub fn get_ref(&self) -> &W { &self.upstream }
pub fn get_ref(&self) -> &W { pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
&self.upstream pub fn into_inner(self) -> W { self.upstream }
}
/// Get mutable reference to upstream pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn get_mut(&mut self) -> &mut W { pub fn state_name(&self) -> &'static str { self.state.state_name() }
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> W {
self.upstream
}
/// Check if stream is in poisoned state
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
/// Get current state name (for debugging)
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
/// Check if there's a pending record to write
pub fn has_pending(&self) -> bool { pub fn has_pending(&self) -> bool {
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty()) matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
} }
/// Transition to poisoned state
fn poison(&mut self, error: io::Error) { fn poison(&mut self, error: io::Error) {
self.state = TlsWriterState::Poisoned { error: Some(error) }; self.state = TlsWriterState::Poisoned { error: Some(error) };
} }
/// Take error from poisoned state
fn take_poison_error(&mut self) -> io::Error { fn take_poison_error(&mut self) -> io::Error {
match &mut self.state { match &mut self.state {
TlsWriterState::Poisoned { error } => { TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
error.take().unwrap_or_else(|| {
io::Error::new(ErrorKind::Other, "stream previously poisoned") io::Error::new(ErrorKind::Other, "stream previously poisoned")
}) }),
}
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
} }
} }
/// Build a TLS Application Data record
fn build_record(data: &[u8]) -> BytesMut { fn build_record(data: &[u8]) -> BytesMut {
let header = TlsRecordHeader { let header = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION, record_type: TLS_RECORD_APPLICATION,
@@ -689,18 +633,13 @@ impl<W> FakeTlsWriter<W> {
} }
} }
/// Result of flushing pending record
enum FlushResult { enum FlushResult {
/// All data flushed, returns payload size
Complete(usize), Complete(usize),
/// Need to wait for upstream
Pending, Pending,
/// Error occurred
Error(io::Error), Error(io::Error),
} }
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> { impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Try to flush pending record to upstream (standalone logic)
fn poll_flush_record_inner( fn poll_flush_record_inner(
upstream: &mut W, upstream: &mut W,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@@ -710,19 +649,14 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
let data = record.pending(); let data = record.pending();
match Pin::new(&mut *upstream).poll_write(cx, data) { match Pin::new(&mut *upstream).poll_write(cx, data) {
Poll::Pending => return FlushResult::Pending, Poll::Pending => return FlushResult::Pending,
Poll::Ready(Err(e)) => return FlushResult::Error(e), Poll::Ready(Err(e)) => return FlushResult::Error(e),
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
return FlushResult::Error(Error::new( return FlushResult::Error(Error::new(
ErrorKind::WriteZero, ErrorKind::WriteZero,
"upstream returned 0 bytes written" "upstream returned 0 bytes written",
)); ));
} }
Poll::Ready(Ok(n)) => record.advance(n),
Poll::Ready(Ok(n)) => {
record.advance(n);
}
} }
} }
@@ -738,7 +672,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
) -> Poll<Result<usize>> { ) -> Poll<Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state // Take ownership of state to avoid borrow conflicts.
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
@@ -751,7 +685,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
TlsWriterState::WritingRecord { mut record, payload_size } => { TlsWriterState::WritingRecord { mut record, payload_size } => {
// Continue flushing existing record // Finish writing previous record before accepting new bytes.
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => { FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size }; this.state = TlsWriterState::WritingRecord { record, payload_size };
@@ -763,7 +697,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
FlushResult::Complete(_) => { FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle; this.state = TlsWriterState::Idle;
// Fall through to handle new write // continue to accept new buf below
} }
} }
} }
@@ -782,19 +716,18 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let chunk_size = buf.len().min(MAX_TLS_PAYLOAD); let chunk_size = buf.len().min(MAX_TLS_PAYLOAD);
let chunk = &buf[..chunk_size]; let chunk = &buf[..chunk_size];
// Build the complete record // Build the complete record (header + payload)
let record_data = Self::build_record(chunk); let record_data = Self::build_record(chunk);
// Try to write directly first
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) { match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
Poll::Ready(Ok(n)) if n == record_data.len() => { Poll::Ready(Ok(n)) if n == record_data.len() => {
// Complete record written
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
// Partial write - buffer the rest // Partial write of the record: store remainder.
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
// record_data length is <= 16389, fits MAX_PENDING_WRITE
let _ = write_buffer.extend(&record_data[n..]); let _ = write_buffer.extend(&record_data[n..]);
this.state = TlsWriterState::WritingRecord { this.state = TlsWriterState::WritingRecord {
@@ -802,7 +735,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
payload_size: chunk_size, payload_size: chunk_size,
}; };
// We've accepted chunk_size bytes from caller // We have accepted chunk_size bytes from caller.
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
@@ -812,7 +745,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
Poll::Pending => { Poll::Pending => {
// Buffer the entire record // Buffer entire record and report success for this chunk.
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
let _ = write_buffer.extend(&record_data); let _ = write_buffer.extend(&record_data);
@@ -821,10 +754,9 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
payload_size: chunk_size, payload_size: chunk_size,
}; };
// Wake to try again // Wake to retry flushing soon.
cx.waker().wake_by_ref(); cx.waker().wake_by_ref();
// We've accepted chunk_size bytes from caller
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
} }
@@ -833,7 +765,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
@@ -866,48 +797,33 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
} }
// Flush upstream
Pin::new(&mut this.upstream).poll_flush(cx) Pin::new(&mut this.upstream).poll_flush(cx)
} }
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
TlsWriterState::WritingRecord { mut record, payload_size } => { TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
// Try to flush pending (best effort) // Best-effort flush (do not block shutdown forever).
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
FlushResult::Pending => {
// Can't complete flush, continue with shutdown anyway
this.state = TlsWriterState::Idle; this.state = TlsWriterState::Idle;
} }
FlushResult::Error(_) => {
// Ignore errors during shutdown
this.state = TlsWriterState::Idle;
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
}
}
_ => { _ => {
this.state = TlsWriterState::Idle; this.state = TlsWriterState::Idle;
} }
} }
// Shutdown upstream
Pin::new(&mut this.upstream).poll_shutdown(cx) Pin::new(&mut this.upstream).poll_shutdown(cx)
} }
} }
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> { impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Write all data wrapped in TLS records (async method) /// Write all data wrapped in TLS records.
/// ///
/// This convenience method handles chunking large data into /// Convenience method that chunks into <= 16384 records.
/// multiple TLS records automatically.
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
let mut written = 0; let mut written = 0;
while written < data.len() { while written < data.len() {

View File

@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
} }
// Set buffer sizes // CHANGED: Removed manual buffer size setting (was 256KB).
set_buffer_sizes(&socket, 65536, 65536)?; // Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
Ok(()) Ok(())
} }
/// Set socket buffer sizes
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
// These may fail on some systems, so we ignore errors
let _ = socket.set_recv_buffer_size(recv);
let _ = socket.set_send_buffer_size(send);
Ok(())
}
/// Configure socket for accepting client connections /// Configure socket for accepting client connections
pub fn configure_client_socket( pub fn configure_client_socket(
stream: &TcpStream, stream: &TcpStream,
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
// Set TCP user timeout (Linux only) // Set TCP user timeout (Linux only)
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
// is implemented in relay_bidirectional instead
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;