@@ -163,9 +163,12 @@ fn default_port() -> u16 { 443 }
|
|||||||
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
fn default_tls_domain() -> String { "www.google.com".to_string() }
|
||||||
fn default_mask_port() -> u16 { 443 }
|
fn default_mask_port() -> u16 { 443 }
|
||||||
fn default_replay_check_len() -> usize { 65536 }
|
fn default_replay_check_len() -> usize { 65536 }
|
||||||
fn default_handshake_timeout() -> u64 { 10 }
|
// CHANGED: Increased handshake timeout for bad mobile networks
|
||||||
|
fn default_handshake_timeout() -> u64 { 15 }
|
||||||
fn default_connect_timeout() -> u64 { 10 }
|
fn default_connect_timeout() -> u64 { 10 }
|
||||||
fn default_keepalive() -> u64 { 600 }
|
// CHANGED: Reduced keepalive from 600s to 60s.
|
||||||
|
// Mobile NATs often drop idle connections after 60-120s.
|
||||||
|
fn default_keepalive() -> u64 { 60 }
|
||||||
fn default_ack_timeout() -> u64 { 300 }
|
fn default_ack_timeout() -> u64 { 300 }
|
||||||
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
|
||||||
fn default_fake_cert_len() -> usize { 2048 }
|
fn default_fake_cert_len() -> usize { 2048 }
|
||||||
|
|||||||
13
src/main.rs
13
src/main.rs
@@ -20,7 +20,7 @@ mod util;
|
|||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::proxy::ClientHandler;
|
use crate::proxy::ClientHandler;
|
||||||
use crate::stats::Stats;
|
use crate::stats::{Stats, ReplayChecker};
|
||||||
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||||
use crate::util::ip::detect_ip;
|
use crate::util::ip::detect_ip;
|
||||||
|
|
||||||
@@ -55,6 +55,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let stats = Arc::new(Stats::new());
|
let stats = Arc::new(Stats::new());
|
||||||
|
|
||||||
|
// CHANGED: Initialize global ReplayChecker here instead of per-connection
|
||||||
|
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
||||||
|
|
||||||
// Initialize Upstream Manager
|
// Initialize Upstream Manager
|
||||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||||
|
|
||||||
@@ -145,13 +148,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Accept loop
|
// Accept loop
|
||||||
// For simplicity in this slice, we just spawn a task for each listener
|
|
||||||
// In a real high-perf scenario, we might want a more complex accept loop
|
|
||||||
|
|
||||||
for listener in listeners {
|
for listener in listeners {
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
let stats = stats.clone();
|
let stats = stats.clone();
|
||||||
let upstream_manager = upstream_manager.clone();
|
let upstream_manager = upstream_manager.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
@@ -160,6 +161,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
let stats = stats.clone();
|
let stats = stats.clone();
|
||||||
let upstream_manager = upstream_manager.clone();
|
let upstream_manager = upstream_manager.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = ClientHandler::new(
|
if let Err(e) = ClientHandler::new(
|
||||||
@@ -167,7 +169,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
peer_addr,
|
peer_addr,
|
||||||
config,
|
config,
|
||||||
stats,
|
stats,
|
||||||
upstream_manager
|
upstream_manager,
|
||||||
|
replay_checker // Pass global checker
|
||||||
).run().await {
|
).run().await {
|
||||||
// Log only relevant errors
|
// Log only relevant errors
|
||||||
// debug!("Connection error: {}", e);
|
// debug!("Connection error: {}", e);
|
||||||
|
|||||||
@@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
|
|||||||
// ============= Buffer Sizes =============
|
// ============= Buffer Sizes =============
|
||||||
|
|
||||||
/// Default buffer size
|
/// Default buffer size
|
||||||
pub const DEFAULT_BUFFER_SIZE: usize = 65536;
|
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
|
||||||
|
/// the new buffering strategy for better iOS upload performance.
|
||||||
|
pub const DEFAULT_BUFFER_SIZE: usize = 16384;
|
||||||
|
|
||||||
/// Small buffer size for bad client handling
|
/// Small buffer size for bad client handling
|
||||||
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
pub const SMALL_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
|||||||
@@ -45,11 +45,10 @@ impl ClientHandler {
|
|||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
upstream_manager: Arc<UpstreamManager>,
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker
|
||||||
) -> RunningClientHandler {
|
) -> RunningClientHandler {
|
||||||
// Note: ReplayChecker should be shared globally for proper replay protection
|
// CHANGED: Removed local creation of ReplayChecker.
|
||||||
// Creating it per-connection disables replay protection across connections
|
// It is now passed from main.rs to ensure global replay protection.
|
||||||
// TODO: Pass Arc<ReplayChecker> from main.rs
|
|
||||||
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
|
|
||||||
|
|
||||||
RunningClientHandler {
|
RunningClientHandler {
|
||||||
stream,
|
stream,
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
//! Bidirectional Relay
|
//! Bidirectional Relay
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||||
use tracing::{debug, trace, warn};
|
use tokio::time::Instant;
|
||||||
|
use tracing::{debug, trace, warn, info};
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
||||||
const BUFFER_SIZE: usize = 65536;
|
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||||
|
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
|
||||||
|
const BUFFER_SIZE: usize = 16384;
|
||||||
|
|
||||||
|
// Activity timeout for iOS compatibility (30 minutes)
|
||||||
|
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
|
||||||
|
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
|
||||||
|
|
||||||
/// Relay data bidirectionally between client and server
|
/// Relay data bidirectionally between client and server
|
||||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||||
@@ -36,15 +44,40 @@ where
|
|||||||
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
|
||||||
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
|
||||||
|
|
||||||
// Client -> Server task
|
// Activity timeout for iOS compatibility
|
||||||
|
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
|
||||||
|
|
||||||
|
// Client -> Server task with activity timeout
|
||||||
let c2s = tokio::spawn(async move {
|
let c2s = tokio::spawn(async move {
|
||||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||||
let mut total_bytes = 0u64;
|
let mut total_bytes = 0u64;
|
||||||
let mut msg_count = 0u64;
|
let mut msg_count = 0u64;
|
||||||
|
let mut last_activity = Instant::now();
|
||||||
|
let mut last_log = Instant::now();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match client_reader.read(&mut buf).await {
|
// Read with timeout to prevent infinite hang on iOS
|
||||||
Ok(0) => {
|
let read_result = tokio::time::timeout(
|
||||||
|
activity_timeout,
|
||||||
|
client_reader.read(&mut buf)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
match read_result {
|
||||||
|
// Timeout - no activity for too long
|
||||||
|
Err(_) => {
|
||||||
|
warn!(
|
||||||
|
user = %user_c2s,
|
||||||
|
total_bytes = total_bytes,
|
||||||
|
msgs = msg_count,
|
||||||
|
idle_secs = last_activity.elapsed().as_secs(),
|
||||||
|
"Activity timeout (C->S) - no data received"
|
||||||
|
);
|
||||||
|
let _ = server_writer.shutdown().await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read successful
|
||||||
|
Ok(Ok(0)) => {
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_c2s,
|
user = %user_c2s,
|
||||||
total_bytes = total_bytes,
|
total_bytes = total_bytes,
|
||||||
@@ -54,9 +87,11 @@ where
|
|||||||
let _ = server_writer.shutdown().await;
|
let _ = server_writer.shutdown().await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
|
||||||
|
Ok(Ok(n)) => {
|
||||||
total_bytes += n as u64;
|
total_bytes += n as u64;
|
||||||
msg_count += 1;
|
msg_count += 1;
|
||||||
|
last_activity = Instant::now();
|
||||||
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||||
|
|
||||||
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
|
||||||
@@ -70,6 +105,19 @@ where
|
|||||||
"C->S data"
|
"C->S data"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Log activity every 10 seconds for large transfers
|
||||||
|
if last_log.elapsed() > Duration::from_secs(10) {
|
||||||
|
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||||
|
info!(
|
||||||
|
user = %user_c2s,
|
||||||
|
total_bytes = total_bytes,
|
||||||
|
msgs = msg_count,
|
||||||
|
rate_kbps = (rate / 1024.0) as u64,
|
||||||
|
"C->S transfer in progress"
|
||||||
|
);
|
||||||
|
last_log = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
if let Err(e) = server_writer.write_all(&buf[..n]).await {
|
||||||
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
debug!(user = %user_c2s, error = %e, "Failed to write to server");
|
||||||
break;
|
break;
|
||||||
@@ -79,7 +127,8 @@ where
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
|
||||||
|
Ok(Err(e)) => {
|
||||||
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -87,15 +136,37 @@ where
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Server -> Client task
|
// Server -> Client task with activity timeout
|
||||||
let s2c = tokio::spawn(async move {
|
let s2c = tokio::spawn(async move {
|
||||||
let mut buf = vec![0u8; BUFFER_SIZE];
|
let mut buf = vec![0u8; BUFFER_SIZE];
|
||||||
let mut total_bytes = 0u64;
|
let mut total_bytes = 0u64;
|
||||||
let mut msg_count = 0u64;
|
let mut msg_count = 0u64;
|
||||||
|
let mut last_activity = Instant::now();
|
||||||
|
let mut last_log = Instant::now();
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match server_reader.read(&mut buf).await {
|
// Read with timeout to prevent infinite hang on iOS
|
||||||
Ok(0) => {
|
let read_result = tokio::time::timeout(
|
||||||
|
activity_timeout,
|
||||||
|
server_reader.read(&mut buf)
|
||||||
|
).await;
|
||||||
|
|
||||||
|
match read_result {
|
||||||
|
// Timeout - no activity for too long
|
||||||
|
Err(_) => {
|
||||||
|
warn!(
|
||||||
|
user = %user_s2c,
|
||||||
|
total_bytes = total_bytes,
|
||||||
|
msgs = msg_count,
|
||||||
|
idle_secs = last_activity.elapsed().as_secs(),
|
||||||
|
"Activity timeout (S->C) - no data received"
|
||||||
|
);
|
||||||
|
let _ = client_writer.shutdown().await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read successful
|
||||||
|
Ok(Ok(0)) => {
|
||||||
debug!(
|
debug!(
|
||||||
user = %user_s2c,
|
user = %user_s2c,
|
||||||
total_bytes = total_bytes,
|
total_bytes = total_bytes,
|
||||||
@@ -105,9 +176,11 @@ where
|
|||||||
let _ = client_writer.shutdown().await;
|
let _ = client_writer.shutdown().await;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
|
||||||
|
Ok(Ok(n)) => {
|
||||||
total_bytes += n as u64;
|
total_bytes += n as u64;
|
||||||
msg_count += 1;
|
msg_count += 1;
|
||||||
|
last_activity = Instant::now();
|
||||||
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
|
||||||
|
|
||||||
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
|
||||||
@@ -121,6 +194,19 @@ where
|
|||||||
"S->C data"
|
"S->C data"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Log activity every 10 seconds for large transfers
|
||||||
|
if last_log.elapsed() > Duration::from_secs(10) {
|
||||||
|
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
|
||||||
|
info!(
|
||||||
|
user = %user_s2c,
|
||||||
|
total_bytes = total_bytes,
|
||||||
|
msgs = msg_count,
|
||||||
|
rate_kbps = (rate / 1024.0) as u64,
|
||||||
|
"S->C transfer in progress"
|
||||||
|
);
|
||||||
|
last_log = Instant::now();
|
||||||
|
}
|
||||||
|
|
||||||
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
if let Err(e) = client_writer.write_all(&buf[..n]).await {
|
||||||
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
debug!(user = %user_s2c, error = %e, "Failed to write to client");
|
||||||
break;
|
break;
|
||||||
@@ -130,7 +216,8 @@ where
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
|
||||||
|
Ok(Err(e)) => {
|
||||||
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
// ============= Configuration =============
|
// ============= Configuration =============
|
||||||
|
|
||||||
/// Default buffer size (64KB - good for MTProto)
|
/// Default buffer size
|
||||||
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
|
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
|
||||||
|
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
|
||||||
|
|
||||||
/// Default maximum number of pooled buffers
|
/// Default maximum number of pooled buffers
|
||||||
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
pub const DEFAULT_MAX_BUFFERS: usize = 1024;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,36 @@
|
|||||||
//! Fake TLS 1.3 stream wrappers
|
//! Fake TLS 1.3 stream wrappers
|
||||||
//!
|
//!
|
||||||
//! This module provides stateful async stream wrappers that handle
|
//! This module provides stateful async stream wrappers that handle TLS record
|
||||||
//! TLS record framing with proper partial read/write handling.
|
//! framing with proper partial read/write handling.
|
||||||
//!
|
//!
|
||||||
//! These are "fake" TLS streams - they wrap data in valid TLS 1.3
|
//! These are "fake" TLS streams:
|
||||||
//! Application Data records but don't perform actual TLS encryption.
|
//! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data).
|
||||||
//! The actual encryption is handled by the crypto layer underneath.
|
//! - We DO NOT perform real TLS handshake/encryption.
|
||||||
|
//! - Real crypto for MTProto is handled by the crypto layer underneath.
|
||||||
|
//!
|
||||||
|
//! Why do we need this?
|
||||||
|
//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for
|
||||||
|
//! domain fronting / traffic camouflage. iOS Telegram clients are known to
|
||||||
|
//! produce slightly different TLS record sizing patterns than Android/Desktop,
|
||||||
|
//! including records that exceed 16384 payload bytes by a small overhead.
|
||||||
//!
|
//!
|
||||||
//! Key design principles:
|
//! Key design principles:
|
||||||
//! - Explicit state machines for all async operations
|
//! - Explicit state machines for all async operations
|
||||||
//! - Never lose data on partial reads
|
//! - Never lose data on partial reads
|
||||||
//! - Atomic TLS record formation for writes
|
//! - Atomic TLS record formation for writes
|
||||||
//! - Proper handling of all TLS record types
|
//! - Proper handling of all TLS record types
|
||||||
|
//!
|
||||||
|
//! Important nuance (Telegram FakeTLS):
|
||||||
|
//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes.
|
||||||
|
//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3
|
||||||
|
//! uses AEAD and can include tag/overhead/padding.
|
||||||
|
//! - Telegram FakeTLS clients (notably iOS) may send Application Data records
|
||||||
|
//! with length up to 16384 + 24 bytes. We accept that as MAX_TLS_CHUNK_SIZE.
|
||||||
|
//!
|
||||||
|
//! If you reject those (e.g. validate length <= 16384), you will see errors like:
|
||||||
|
//! "TLS record too large: 16408 bytes"
|
||||||
|
//! and uploads from iOS will break (media/file sending), while small traffic
|
||||||
|
//! may still work.
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut, BufMut};
|
use bytes::{Bytes, BytesMut, BufMut};
|
||||||
use std::io::{self, Error, ErrorKind, Result};
|
use std::io::{self, Error, ErrorKind, Result};
|
||||||
@@ -20,25 +39,29 @@ use std::task::{Context, Poll};
|
|||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||||
|
|
||||||
use crate::protocol::constants::{
|
use crate::protocol::constants::{
|
||||||
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
TLS_VERSION,
|
||||||
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE,
|
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||||
|
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
|
||||||
|
MAX_TLS_CHUNK_SIZE,
|
||||||
};
|
};
|
||||||
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
|
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
|
||||||
|
|
||||||
// ============= Constants =============
|
// ============= Constants =============
|
||||||
|
|
||||||
/// TLS record header size
|
/// TLS record header size (type + version + length)
|
||||||
const TLS_HEADER_SIZE: usize = 5;
|
const TLS_HEADER_SIZE: usize = 5;
|
||||||
|
|
||||||
/// Maximum TLS record payload size (16KB as per TLS spec)
|
/// Maximum TLS fragment size per spec (plaintext fragment).
|
||||||
|
/// We use this for *outgoing* chunking, because we build plain ApplicationData records.
|
||||||
const MAX_TLS_PAYLOAD: usize = 16384;
|
const MAX_TLS_PAYLOAD: usize = 16384;
|
||||||
|
|
||||||
/// Maximum pending write buffer
|
/// Maximum pending write buffer for one record remainder.
|
||||||
|
/// Note: we never queue unlimited amount of data here; state holds at most one record.
|
||||||
const MAX_PENDING_WRITE: usize = 64 * 1024;
|
const MAX_PENDING_WRITE: usize = 64 * 1024;
|
||||||
|
|
||||||
// ============= TLS Record Types =============
|
// ============= TLS Record Types =============
|
||||||
|
|
||||||
/// Parsed TLS record header
|
/// Parsed TLS record header (5 bytes)
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
struct TlsRecordHeader {
|
struct TlsRecordHeader {
|
||||||
/// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.)
|
/// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.)
|
||||||
@@ -50,22 +73,27 @@ struct TlsRecordHeader {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TlsRecordHeader {
|
impl TlsRecordHeader {
|
||||||
/// Parse header from 5 bytes
|
/// Parse header from exactly 5 bytes.
|
||||||
|
///
|
||||||
|
/// This currently never returns None, but is kept as Option to allow future
|
||||||
|
/// stricter parsing rules without changing callers.
|
||||||
fn parse(header: &[u8; 5]) -> Option<Self> {
|
fn parse(header: &[u8; 5]) -> Option<Self> {
|
||||||
let record_type = header[0];
|
let record_type = header[0];
|
||||||
let version = [header[1], header[2]];
|
let version = [header[1], header[2]];
|
||||||
let length = u16::from_be_bytes([header[3], header[4]]);
|
let length = u16::from_be_bytes([header[3], header[4]]);
|
||||||
|
Some(Self { record_type, version, length })
|
||||||
Some(Self {
|
|
||||||
record_type,
|
|
||||||
version,
|
|
||||||
length,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Validate the header
|
/// Validate the header.
|
||||||
|
///
|
||||||
|
/// Nuances:
|
||||||
|
/// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01),
|
||||||
|
/// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03).
|
||||||
|
/// - For Application Data, Telegram FakeTLS may send payload length up to
|
||||||
|
/// MAX_TLS_CHUNK_SIZE (16384 + 24).
|
||||||
|
/// - For other record types we keep stricter bounds to avoid memory abuse.
|
||||||
fn validate(&self) -> Result<()> {
|
fn validate(&self) -> Result<()> {
|
||||||
// Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others)
|
// Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303).
|
||||||
if self.version != [0x03, 0x01] && self.version != TLS_VERSION {
|
if self.version != [0x03, 0x01] && self.version != TLS_VERSION {
|
||||||
return Err(Error::new(
|
return Err(Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
@@ -73,27 +101,36 @@ impl TlsRecordHeader {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check length
|
let len = self.length as usize;
|
||||||
if self.length as usize > MAX_TLS_RECORD_SIZE {
|
|
||||||
|
// Length checks depend on record type.
|
||||||
|
// Telegram FakeTLS: ApplicationData length may be 16384 + 24.
|
||||||
|
match self.record_type {
|
||||||
|
TLS_RECORD_APPLICATION => {
|
||||||
|
if len > MAX_TLS_CHUNK_SIZE {
|
||||||
return Err(Error::new(
|
return Err(Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("TLS record too large: {} bytes", self.length),
|
format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChangeCipherSpec/Alert/Handshake should never be that large for our usage
|
||||||
|
// (post-handshake we don't expect Handshake at all).
|
||||||
|
// Keep strict to reduce attack surface.
|
||||||
|
_ => {
|
||||||
|
if len > MAX_TLS_PAYLOAD {
|
||||||
|
return Err(Error::new(
|
||||||
|
ErrorKind::InvalidData,
|
||||||
|
format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if this is an application data record
|
|
||||||
fn is_application_data(&self) -> bool {
|
|
||||||
self.record_type == TLS_RECORD_APPLICATION
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if this is a change cipher spec record (should be skipped)
|
|
||||||
fn is_change_cipher_spec(&self) -> bool {
|
|
||||||
self.record_type == TLS_RECORD_CHANGE_CIPHER
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build header bytes
|
/// Build header bytes
|
||||||
fn to_bytes(&self) -> [u8; 5] {
|
fn to_bytes(&self) -> [u8; 5] {
|
||||||
[
|
[
|
||||||
@@ -120,25 +157,20 @@ enum TlsReaderState {
|
|||||||
header: HeaderBuffer<TLS_HEADER_SIZE>,
|
header: HeaderBuffer<TLS_HEADER_SIZE>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Reading the TLS record body
|
/// Reading the TLS record body (payload)
|
||||||
ReadingBody {
|
ReadingBody {
|
||||||
/// Parsed record type
|
|
||||||
record_type: u8,
|
record_type: u8,
|
||||||
/// Total body length
|
|
||||||
length: usize,
|
length: usize,
|
||||||
/// Buffer for body data
|
|
||||||
buffer: BytesMut,
|
buffer: BytesMut,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Have decrypted data ready to yield to caller
|
/// Have buffered data ready to yield to caller
|
||||||
Yielding {
|
Yielding {
|
||||||
/// Buffer containing data to yield
|
|
||||||
buffer: YieldBuffer,
|
buffer: YieldBuffer,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Stream encountered an error and cannot be used
|
/// Stream encountered an error and cannot be used
|
||||||
Poisoned {
|
Poisoned {
|
||||||
/// The error that caused poisoning
|
|
||||||
error: Option<io::Error>,
|
error: Option<io::Error>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -165,12 +197,13 @@ impl StreamState for TlsReaderState {
|
|||||||
|
|
||||||
// ============= FakeTlsReader =============
|
// ============= FakeTlsReader =============
|
||||||
|
|
||||||
/// Reader that unwraps TLS 1.3 records with proper state machine
|
/// Reader that unwraps TLS records (FakeTLS).
|
||||||
///
|
///
|
||||||
/// This reader handles partial reads correctly by maintaining internal state
|
/// This wrapper is responsible ONLY for TLS record framing and skipping
|
||||||
/// and never losing any data that has been read from upstream.
|
/// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed
|
||||||
|
/// as-is to upper layers (crypto stream).
|
||||||
///
|
///
|
||||||
/// # State Machine
|
/// State machine overview:
|
||||||
///
|
///
|
||||||
/// ┌──────────┐ ┌───────────────┐
|
/// ┌──────────┐ ┌───────────────┐
|
||||||
/// │ Idle │ -----------------> │ ReadingHeader │
|
/// │ Idle │ -----------------> │ ReadingHeader │
|
||||||
@@ -178,103 +211,69 @@ impl StreamState for TlsReaderState {
|
|||||||
/// ▲ │
|
/// ▲ │
|
||||||
/// │ header complete
|
/// │ header complete
|
||||||
/// │ │
|
/// │ │
|
||||||
/// │ │
|
/// │ ▼
|
||||||
/// │ ┌───────────────┐
|
/// │ ┌───────────────┐
|
||||||
/// │ skip record │ ReadingBody │
|
/// │ skip record │ ReadingBody │
|
||||||
/// │ <-------- (CCS) -------- │ │
|
/// │ <-------- (CCS) -------- │ │
|
||||||
/// │ └───────┬───────┘
|
/// │ └───────┬───────┘
|
||||||
/// │ │
|
/// │ │
|
||||||
/// │ body complete
|
/// │ body complete
|
||||||
/// │ drained │
|
/// │ ▼
|
||||||
/// │ <-----------------┐ │
|
/// │ ┌───────────────┐
|
||||||
/// │ │ ┌───────────────┐
|
/// │ │ Yielding │
|
||||||
/// │ └----- │ Yielding │
|
|
||||||
/// │ └───────────────┘
|
/// │ └───────────────┘
|
||||||
/// │
|
/// │
|
||||||
/// │ errors /w any state
|
/// │ errors / w any state
|
||||||
/// │
|
/// ▼
|
||||||
/// ┌───────────────────────────────────────────────┐
|
/// ┌───────────────────────────────────────────────┐
|
||||||
/// │ Poisoned │
|
/// │ Poisoned │
|
||||||
/// └───────────────────────────────────────────────┘
|
/// └───────────────────────────────────────────────┘
|
||||||
///
|
///
|
||||||
|
/// NOTE: We must correctly handle partial reads from upstream:
|
||||||
|
/// - do not assume header arrives in one poll
|
||||||
|
/// - do not assume body arrives in one poll
|
||||||
|
/// - never lose already-read bytes
|
||||||
pub struct FakeTlsReader<R> {
|
pub struct FakeTlsReader<R> {
|
||||||
/// Upstream reader
|
|
||||||
upstream: R,
|
upstream: R,
|
||||||
/// Current state
|
|
||||||
state: TlsReaderState,
|
state: TlsReaderState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> FakeTlsReader<R> {
|
impl<R> FakeTlsReader<R> {
|
||||||
/// Create new fake TLS reader
|
|
||||||
pub fn new(upstream: R) -> Self {
|
pub fn new(upstream: R) -> Self {
|
||||||
Self {
|
Self { upstream, state: TlsReaderState::Idle }
|
||||||
upstream,
|
|
||||||
state: TlsReaderState::Idle,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get reference to upstream
|
pub fn get_ref(&self) -> &R { &self.upstream }
|
||||||
pub fn get_ref(&self) -> &R {
|
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||||
&self.upstream
|
pub fn into_inner(self) -> R { self.upstream }
|
||||||
}
|
|
||||||
|
|
||||||
/// Get mutable reference to upstream
|
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||||
pub fn get_mut(&mut self) -> &mut R {
|
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||||
&mut self.upstream
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Consume and return upstream
|
|
||||||
pub fn into_inner(self) -> R {
|
|
||||||
self.upstream
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if stream is in poisoned state
|
|
||||||
pub fn is_poisoned(&self) -> bool {
|
|
||||||
self.state.is_poisoned()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get current state name (for debugging)
|
|
||||||
pub fn state_name(&self) -> &'static str {
|
|
||||||
self.state.state_name()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Transition to poisoned state
|
|
||||||
fn poison(&mut self, error: io::Error) {
|
fn poison(&mut self, error: io::Error) {
|
||||||
self.state = TlsReaderState::Poisoned { error: Some(error) };
|
self.state = TlsReaderState::Poisoned { error: Some(error) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take error from poisoned state
|
|
||||||
fn take_poison_error(&mut self) -> io::Error {
|
fn take_poison_error(&mut self) -> io::Error {
|
||||||
match &mut self.state {
|
match &mut self.state {
|
||||||
TlsReaderState::Poisoned { error } => {
|
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||||
error.take().unwrap_or_else(|| {
|
|
||||||
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
||||||
})
|
}),
|
||||||
}
|
|
||||||
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of polling for header completion
|
|
||||||
enum HeaderPollResult {
|
enum HeaderPollResult {
|
||||||
/// Need more data
|
|
||||||
Pending,
|
Pending,
|
||||||
/// EOF at record boundary (clean close)
|
|
||||||
Eof,
|
Eof,
|
||||||
/// Header complete, parsed successfully
|
|
||||||
Complete(TlsRecordHeader),
|
Complete(TlsRecordHeader),
|
||||||
/// Error occurred
|
|
||||||
Error(io::Error),
|
Error(io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of polling for body completion
|
|
||||||
enum BodyPollResult {
|
enum BodyPollResult {
|
||||||
/// Need more data
|
|
||||||
Pending,
|
Pending,
|
||||||
/// Body complete
|
|
||||||
Complete(Bytes),
|
Complete(Bytes),
|
||||||
/// Error occurred
|
|
||||||
Error(io::Error),
|
Error(io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,7 +290,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
let state = std::mem::replace(&mut this.state, TlsReaderState::Idle);
|
let state = std::mem::replace(&mut this.state, TlsReaderState::Idle);
|
||||||
|
|
||||||
match state {
|
match state {
|
||||||
// Poisoned state - return error
|
// Poisoned state: always return the stored error
|
||||||
TlsReaderState::Poisoned { error } => {
|
TlsReaderState::Poisoned { error } => {
|
||||||
this.state = TlsReaderState::Poisoned { error: None };
|
this.state = TlsReaderState::Poisoned { error: None };
|
||||||
let err = error.unwrap_or_else(|| {
|
let err = error.unwrap_or_else(|| {
|
||||||
@@ -300,20 +299,18 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
return Poll::Ready(Err(err));
|
return Poll::Ready(Err(err));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Have buffered data to yield
|
// Yield buffered plaintext to caller
|
||||||
TlsReaderState::Yielding { mut buffer } => {
|
TlsReaderState::Yielding { mut buffer } => {
|
||||||
if buf.remaining() == 0 {
|
if buf.remaining() == 0 {
|
||||||
this.state = TlsReaderState::Yielding { buffer };
|
this.state = TlsReaderState::Yielding { buffer };
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy as much as possible to output
|
|
||||||
let to_copy = buffer.remaining().min(buf.remaining());
|
let to_copy = buffer.remaining().min(buf.remaining());
|
||||||
let dst = buf.initialize_unfilled_to(to_copy);
|
let dst = buf.initialize_unfilled_to(to_copy);
|
||||||
let copied = buffer.copy_to(dst);
|
let copied = buffer.copy_to(dst);
|
||||||
buf.advance(copied);
|
buf.advance(copied);
|
||||||
|
|
||||||
// If buffer is drained, transition to Idle
|
|
||||||
if buffer.is_empty() {
|
if buffer.is_empty() {
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
} else {
|
} else {
|
||||||
@@ -323,23 +320,21 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ready to read a new TLS record
|
// Start reading new record
|
||||||
TlsReaderState::Idle => {
|
TlsReaderState::Idle => {
|
||||||
if buf.remaining() == 0 {
|
if buf.remaining() == 0 {
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start reading header
|
|
||||||
this.state = TlsReaderState::ReadingHeader {
|
this.state = TlsReaderState::ReadingHeader {
|
||||||
header: HeaderBuffer::new(),
|
header: HeaderBuffer::new(),
|
||||||
};
|
};
|
||||||
// Continue to ReadingHeader
|
// loop continues and will handle ReadingHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reading TLS record header
|
// Read TLS header (5 bytes)
|
||||||
TlsReaderState::ReadingHeader { mut header } => {
|
TlsReaderState::ReadingHeader { mut header } => {
|
||||||
// Poll to fill header
|
|
||||||
let result = poll_read_header(&mut this.upstream, cx, &mut header);
|
let result = poll_read_header(&mut this.upstream, cx, &mut header);
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -348,6 +343,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
return Poll::Pending;
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
HeaderPollResult::Eof => {
|
HeaderPollResult::Eof => {
|
||||||
|
// Clean EOF at record boundary
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
@@ -356,15 +352,12 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
return Poll::Ready(Err(e));
|
return Poll::Ready(Err(e));
|
||||||
}
|
}
|
||||||
HeaderPollResult::Complete(parsed) => {
|
HeaderPollResult::Complete(parsed) => {
|
||||||
// Validate header
|
|
||||||
if let Err(e) = parsed.validate() {
|
if let Err(e) = parsed.validate() {
|
||||||
this.poison(Error::new(e.kind(), e.to_string()));
|
this.poison(Error::new(e.kind(), e.to_string()));
|
||||||
return Poll::Ready(Err(e));
|
return Poll::Ready(Err(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
let length = parsed.length as usize;
|
let length = parsed.length as usize;
|
||||||
|
|
||||||
// Transition to reading body
|
|
||||||
this.state = TlsReaderState::ReadingBody {
|
this.state = TlsReaderState::ReadingBody {
|
||||||
record_type: parsed.record_type,
|
record_type: parsed.record_type,
|
||||||
length,
|
length,
|
||||||
@@ -374,7 +367,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reading TLS record body
|
// Read TLS payload
|
||||||
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
|
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
|
||||||
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
||||||
|
|
||||||
@@ -388,15 +381,15 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
return Poll::Ready(Err(e));
|
return Poll::Ready(Err(e));
|
||||||
}
|
}
|
||||||
BodyPollResult::Complete(data) => {
|
BodyPollResult::Complete(data) => {
|
||||||
// Handle different record types
|
|
||||||
match record_type {
|
match record_type {
|
||||||
TLS_RECORD_CHANGE_CIPHER => {
|
TLS_RECORD_CHANGE_CIPHER => {
|
||||||
// Skip Change Cipher Spec, read next record
|
// CCS is expected in some clients, ignore it.
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
TLS_RECORD_APPLICATION => {
|
TLS_RECORD_APPLICATION => {
|
||||||
// Application data - yield to caller
|
// This is what we actually want.
|
||||||
if data.is_empty() {
|
if data.is_empty() {
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
continue;
|
continue;
|
||||||
@@ -405,25 +398,26 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
|||||||
this.state = TlsReaderState::Yielding {
|
this.state = TlsReaderState::Yielding {
|
||||||
buffer: YieldBuffer::new(data),
|
buffer: YieldBuffer::new(data),
|
||||||
};
|
};
|
||||||
// Continue to yield
|
// loop continues and will yield immediately
|
||||||
}
|
}
|
||||||
|
|
||||||
TLS_RECORD_ALERT => {
|
TLS_RECORD_ALERT => {
|
||||||
// TLS Alert - treat as EOF
|
// Treat TLS alert as EOF-like termination.
|
||||||
this.state = TlsReaderState::Idle;
|
this.state = TlsReaderState::Idle;
|
||||||
return Poll::Ready(Ok(()));
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TLS_RECORD_HANDSHAKE => {
|
TLS_RECORD_HANDSHAKE => {
|
||||||
let err = Error::new(
|
// After FakeTLS handshake is done, we do not expect any Handshake records.
|
||||||
ErrorKind::InvalidData,
|
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
|
||||||
"unexpected TLS handshake record"
|
|
||||||
);
|
|
||||||
this.poison(Error::new(err.kind(), err.to_string()));
|
this.poison(Error::new(err.kind(), err.to_string()));
|
||||||
return Poll::Ready(Err(err));
|
return Poll::Ready(Err(err));
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => {
|
_ => {
|
||||||
let err = Error::new(
|
let err = Error::new(
|
||||||
ErrorKind::InvalidData,
|
ErrorKind::InvalidData,
|
||||||
format!("unknown TLS record type: 0x{:02x}", record_type)
|
format!("unknown TLS record type: 0x{:02x}", record_type),
|
||||||
);
|
);
|
||||||
this.poison(Error::new(err.kind(), err.to_string()));
|
this.poison(Error::new(err.kind(), err.to_string()));
|
||||||
return Poll::Ready(Err(err));
|
return Poll::Ready(Err(err));
|
||||||
@@ -459,8 +453,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
|||||||
} else {
|
} else {
|
||||||
return HeaderPollResult::Error(Error::new(
|
return HeaderPollResult::Error(Error::new(
|
||||||
ErrorKind::UnexpectedEof,
|
ErrorKind::UnexpectedEof,
|
||||||
format!("unexpected EOF in TLS header (got {} of 5 bytes)",
|
format!(
|
||||||
header.as_slice().len())
|
"unexpected EOF in TLS header (got {} of 5 bytes)",
|
||||||
|
header.as_slice().len()
|
||||||
|
),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -469,14 +465,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse header
|
|
||||||
let header_bytes = *header.as_array();
|
let header_bytes = *header.as_array();
|
||||||
match TlsRecordHeader::parse(&header_bytes) {
|
match TlsRecordHeader::parse(&header_bytes) {
|
||||||
Some(h) => HeaderPollResult::Complete(h),
|
Some(h) => HeaderPollResult::Complete(h),
|
||||||
None => HeaderPollResult::Error(Error::new(
|
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
|
||||||
ErrorKind::InvalidData,
|
|
||||||
"failed to parse TLS header"
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -487,10 +479,12 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
|||||||
buffer: &mut BytesMut,
|
buffer: &mut BytesMut,
|
||||||
target_len: usize,
|
target_len: usize,
|
||||||
) -> BodyPollResult {
|
) -> BodyPollResult {
|
||||||
|
// NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime
|
||||||
|
// issues with BytesMut spare capacity and ReadBuf across polls.
|
||||||
|
// It's safe and correct; optimization is possible if needed.
|
||||||
while buffer.len() < target_len {
|
while buffer.len() < target_len {
|
||||||
let remaining = target_len - buffer.len();
|
let remaining = target_len - buffer.len();
|
||||||
|
|
||||||
// Read into a temporary buffer
|
|
||||||
let mut temp = vec![0u8; remaining.min(8192)];
|
let mut temp = vec![0u8; remaining.min(8192)];
|
||||||
let mut read_buf = ReadBuf::new(&mut temp);
|
let mut read_buf = ReadBuf::new(&mut temp);
|
||||||
|
|
||||||
@@ -502,8 +496,11 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
return BodyPollResult::Error(Error::new(
|
return BodyPollResult::Error(Error::new(
|
||||||
ErrorKind::UnexpectedEof,
|
ErrorKind::UnexpectedEof,
|
||||||
format!("unexpected EOF in TLS body (got {} of {} bytes)",
|
format!(
|
||||||
buffer.len(), target_len)
|
"unexpected EOF in TLS body (got {} of {} bytes)",
|
||||||
|
buffer.len(),
|
||||||
|
target_len
|
||||||
|
),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
buffer.extend_from_slice(&temp[..n]);
|
buffer.extend_from_slice(&temp[..n]);
|
||||||
@@ -515,10 +512,9 @@ fn poll_read_body<R: AsyncRead + Unpin>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||||
/// Read exactly n bytes through TLS layer
|
/// Read exactly n bytes through TLS layer.
|
||||||
///
|
///
|
||||||
/// This is a convenience method that accumulates data across
|
/// This accumulates data across multiple TLS ApplicationData records.
|
||||||
/// multiple TLS records until exactly n bytes are available.
|
|
||||||
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
|
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
|
||||||
if self.is_poisoned() {
|
if self.is_poisoned() {
|
||||||
return Err(self.take_poison_error());
|
return Err(self.take_poison_error());
|
||||||
@@ -533,7 +529,7 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
|||||||
if read == 0 {
|
if read == 0 {
|
||||||
return Err(Error::new(
|
return Err(Error::new(
|
||||||
ErrorKind::UnexpectedEof,
|
ErrorKind::UnexpectedEof,
|
||||||
format!("expected {} bytes, got {}", n, result.len())
|
format!("expected {} bytes, got {}", n, result.len()),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -546,23 +542,19 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
|||||||
|
|
||||||
// ============= FakeTlsWriter State =============
|
// ============= FakeTlsWriter State =============
|
||||||
|
|
||||||
/// State machine states for FakeTlsWriter
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum TlsWriterState {
|
enum TlsWriterState {
|
||||||
/// Ready to accept new data
|
/// Ready to accept new data
|
||||||
Idle,
|
Idle,
|
||||||
|
|
||||||
/// Writing a complete TLS record
|
/// Writing a complete TLS record (header + body), possibly partially
|
||||||
WritingRecord {
|
WritingRecord {
|
||||||
/// Complete record (header + body) to write
|
|
||||||
record: WriteBuffer,
|
record: WriteBuffer,
|
||||||
/// Original payload size (for return value calculation)
|
|
||||||
payload_size: usize,
|
payload_size: usize,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Stream encountered an error and cannot be used
|
/// Stream encountered an error and cannot be used
|
||||||
Poisoned {
|
Poisoned {
|
||||||
/// The error that caused poisoning
|
|
||||||
error: Option<io::Error>,
|
error: Option<io::Error>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -587,94 +579,46 @@ impl StreamState for TlsWriterState {
|
|||||||
|
|
||||||
// ============= FakeTlsWriter =============
|
// ============= FakeTlsWriter =============
|
||||||
|
|
||||||
/// Writer that wraps data in TLS 1.3 records with proper state machine
|
/// Writer that wraps bytes into TLS 1.3 Application Data records.
|
||||||
///
|
///
|
||||||
/// This writer handles partial writes correctly by:
|
/// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD).
|
||||||
/// - Building complete TLS records before writing
|
/// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it.
|
||||||
/// - Maintaining internal state for partial record writes
|
/// If you want to be more camouflage-accurate later, you could add optional padding
|
||||||
/// - Never splitting a record mid-write to upstream
|
/// to produce records sized closer to MAX_TLS_CHUNK_SIZE.
|
||||||
///
|
|
||||||
/// # State Machine
|
|
||||||
///
|
|
||||||
/// ┌──────────┐ write ┌─────────────────┐
|
|
||||||
/// │ Idle │ -------------> │ WritingRecord │
|
|
||||||
/// │ │ <------------- │ │
|
|
||||||
/// └──────────┘ complete └─────────────────┘
|
|
||||||
/// │ │
|
|
||||||
/// │ < errors > │
|
|
||||||
/// │ │
|
|
||||||
/// ┌─────────────────────────────────────────────┐
|
|
||||||
/// │ Poisoned │
|
|
||||||
/// └─────────────────────────────────────────────┘
|
|
||||||
///
|
|
||||||
/// # Record Formation
|
|
||||||
///
|
|
||||||
/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes.
|
|
||||||
/// Each record has a 5-byte header prepended.
|
|
||||||
pub struct FakeTlsWriter<W> {
|
pub struct FakeTlsWriter<W> {
|
||||||
/// Upstream writer
|
|
||||||
upstream: W,
|
upstream: W,
|
||||||
/// Current state
|
|
||||||
state: TlsWriterState,
|
state: TlsWriterState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W> FakeTlsWriter<W> {
|
impl<W> FakeTlsWriter<W> {
|
||||||
/// Create new fake TLS writer
|
|
||||||
pub fn new(upstream: W) -> Self {
|
pub fn new(upstream: W) -> Self {
|
||||||
Self {
|
Self { upstream, state: TlsWriterState::Idle }
|
||||||
upstream,
|
|
||||||
state: TlsWriterState::Idle,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get reference to upstream
|
pub fn get_ref(&self) -> &W { &self.upstream }
|
||||||
pub fn get_ref(&self) -> &W {
|
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||||
&self.upstream
|
pub fn into_inner(self) -> W { self.upstream }
|
||||||
}
|
|
||||||
|
|
||||||
/// Get mutable reference to upstream
|
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||||
pub fn get_mut(&mut self) -> &mut W {
|
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||||
&mut self.upstream
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Consume and return upstream
|
|
||||||
pub fn into_inner(self) -> W {
|
|
||||||
self.upstream
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if stream is in poisoned state
|
|
||||||
pub fn is_poisoned(&self) -> bool {
|
|
||||||
self.state.is_poisoned()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get current state name (for debugging)
|
|
||||||
pub fn state_name(&self) -> &'static str {
|
|
||||||
self.state.state_name()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if there's a pending record to write
|
|
||||||
pub fn has_pending(&self) -> bool {
|
pub fn has_pending(&self) -> bool {
|
||||||
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
|
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transition to poisoned state
|
|
||||||
fn poison(&mut self, error: io::Error) {
|
fn poison(&mut self, error: io::Error) {
|
||||||
self.state = TlsWriterState::Poisoned { error: Some(error) };
|
self.state = TlsWriterState::Poisoned { error: Some(error) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take error from poisoned state
|
|
||||||
fn take_poison_error(&mut self) -> io::Error {
|
fn take_poison_error(&mut self) -> io::Error {
|
||||||
match &mut self.state {
|
match &mut self.state {
|
||||||
TlsWriterState::Poisoned { error } => {
|
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||||
error.take().unwrap_or_else(|| {
|
|
||||||
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
io::Error::new(ErrorKind::Other, "stream previously poisoned")
|
||||||
})
|
}),
|
||||||
}
|
|
||||||
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build a TLS Application Data record
|
|
||||||
fn build_record(data: &[u8]) -> BytesMut {
|
fn build_record(data: &[u8]) -> BytesMut {
|
||||||
let header = TlsRecordHeader {
|
let header = TlsRecordHeader {
|
||||||
record_type: TLS_RECORD_APPLICATION,
|
record_type: TLS_RECORD_APPLICATION,
|
||||||
@@ -689,18 +633,13 @@ impl<W> FakeTlsWriter<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of flushing pending record
|
|
||||||
enum FlushResult {
|
enum FlushResult {
|
||||||
/// All data flushed, returns payload size
|
|
||||||
Complete(usize),
|
Complete(usize),
|
||||||
/// Need to wait for upstream
|
|
||||||
Pending,
|
Pending,
|
||||||
/// Error occurred
|
|
||||||
Error(io::Error),
|
Error(io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||||
/// Try to flush pending record to upstream (standalone logic)
|
|
||||||
fn poll_flush_record_inner(
|
fn poll_flush_record_inner(
|
||||||
upstream: &mut W,
|
upstream: &mut W,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
@@ -710,19 +649,14 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
|||||||
let data = record.pending();
|
let data = record.pending();
|
||||||
match Pin::new(&mut *upstream).poll_write(cx, data) {
|
match Pin::new(&mut *upstream).poll_write(cx, data) {
|
||||||
Poll::Pending => return FlushResult::Pending,
|
Poll::Pending => return FlushResult::Pending,
|
||||||
|
|
||||||
Poll::Ready(Err(e)) => return FlushResult::Error(e),
|
Poll::Ready(Err(e)) => return FlushResult::Error(e),
|
||||||
|
|
||||||
Poll::Ready(Ok(0)) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
return FlushResult::Error(Error::new(
|
return FlushResult::Error(Error::new(
|
||||||
ErrorKind::WriteZero,
|
ErrorKind::WriteZero,
|
||||||
"upstream returned 0 bytes written"
|
"upstream returned 0 bytes written",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
Poll::Ready(Ok(n)) => record.advance(n),
|
||||||
Poll::Ready(Ok(n)) => {
|
|
||||||
record.advance(n);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -738,7 +672,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
) -> Poll<Result<usize>> {
|
) -> Poll<Result<usize>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
|
||||||
// Take ownership of state
|
// Take ownership of state to avoid borrow conflicts.
|
||||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||||
|
|
||||||
match state {
|
match state {
|
||||||
@@ -751,7 +685,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||||
// Continue flushing existing record
|
// Finish writing previous record before accepting new bytes.
|
||||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||||
FlushResult::Pending => {
|
FlushResult::Pending => {
|
||||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||||
@@ -763,7 +697,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
}
|
}
|
||||||
FlushResult::Complete(_) => {
|
FlushResult::Complete(_) => {
|
||||||
this.state = TlsWriterState::Idle;
|
this.state = TlsWriterState::Idle;
|
||||||
// Fall through to handle new write
|
// continue to accept new buf below
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -782,19 +716,18 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
let chunk_size = buf.len().min(MAX_TLS_PAYLOAD);
|
let chunk_size = buf.len().min(MAX_TLS_PAYLOAD);
|
||||||
let chunk = &buf[..chunk_size];
|
let chunk = &buf[..chunk_size];
|
||||||
|
|
||||||
// Build the complete record
|
// Build the complete record (header + payload)
|
||||||
let record_data = Self::build_record(chunk);
|
let record_data = Self::build_record(chunk);
|
||||||
|
|
||||||
// Try to write directly first
|
|
||||||
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
|
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
|
||||||
Poll::Ready(Ok(n)) if n == record_data.len() => {
|
Poll::Ready(Ok(n)) if n == record_data.len() => {
|
||||||
// Complete record written
|
|
||||||
Poll::Ready(Ok(chunk_size))
|
Poll::Ready(Ok(chunk_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
Poll::Ready(Ok(n)) => {
|
Poll::Ready(Ok(n)) => {
|
||||||
// Partial write - buffer the rest
|
// Partial write of the record: store remainder.
|
||||||
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
||||||
|
// record_data length is <= 16389, fits MAX_PENDING_WRITE
|
||||||
let _ = write_buffer.extend(&record_data[n..]);
|
let _ = write_buffer.extend(&record_data[n..]);
|
||||||
|
|
||||||
this.state = TlsWriterState::WritingRecord {
|
this.state = TlsWriterState::WritingRecord {
|
||||||
@@ -802,7 +735,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
payload_size: chunk_size,
|
payload_size: chunk_size,
|
||||||
};
|
};
|
||||||
|
|
||||||
// We've accepted chunk_size bytes from caller
|
// We have accepted chunk_size bytes from caller.
|
||||||
Poll::Ready(Ok(chunk_size))
|
Poll::Ready(Ok(chunk_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -812,7 +745,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
// Buffer the entire record
|
// Buffer entire record and report success for this chunk.
|
||||||
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
|
||||||
let _ = write_buffer.extend(&record_data);
|
let _ = write_buffer.extend(&record_data);
|
||||||
|
|
||||||
@@ -821,10 +754,9 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
payload_size: chunk_size,
|
payload_size: chunk_size,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wake to try again
|
// Wake to retry flushing soon.
|
||||||
cx.waker().wake_by_ref();
|
cx.waker().wake_by_ref();
|
||||||
|
|
||||||
// We've accepted chunk_size bytes from caller
|
|
||||||
Poll::Ready(Ok(chunk_size))
|
Poll::Ready(Ok(chunk_size))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -833,7 +765,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
|
||||||
// Take ownership of state
|
|
||||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||||
|
|
||||||
match state {
|
match state {
|
||||||
@@ -866,48 +797,33 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush upstream
|
|
||||||
Pin::new(&mut this.upstream).poll_flush(cx)
|
Pin::new(&mut this.upstream).poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
|
||||||
// Take ownership of state
|
|
||||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||||
|
|
||||||
match state {
|
match state {
|
||||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
|
||||||
// Try to flush pending (best effort)
|
// Best-effort flush (do not block shutdown forever).
|
||||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
|
||||||
FlushResult::Pending => {
|
|
||||||
// Can't complete flush, continue with shutdown anyway
|
|
||||||
this.state = TlsWriterState::Idle;
|
this.state = TlsWriterState::Idle;
|
||||||
}
|
}
|
||||||
FlushResult::Error(_) => {
|
|
||||||
// Ignore errors during shutdown
|
|
||||||
this.state = TlsWriterState::Idle;
|
|
||||||
}
|
|
||||||
FlushResult::Complete(_) => {
|
|
||||||
this.state = TlsWriterState::Idle;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
this.state = TlsWriterState::Idle;
|
this.state = TlsWriterState::Idle;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown upstream
|
|
||||||
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||||
/// Write all data wrapped in TLS records (async method)
|
/// Write all data wrapped in TLS records.
|
||||||
///
|
///
|
||||||
/// This convenience method handles chunking large data into
|
/// Convenience method that chunks into <= 16384 records.
|
||||||
/// multiple TLS records automatically.
|
|
||||||
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
|
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
|
||||||
let mut written = 0;
|
let mut written = 0;
|
||||||
while written < data.len() {
|
while written < data.len() {
|
||||||
|
|||||||
@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
|
|||||||
socket.set_tcp_keepalive(&keepalive)?;
|
socket.set_tcp_keepalive(&keepalive)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set buffer sizes
|
// CHANGED: Removed manual buffer size setting (was 256KB).
|
||||||
set_buffer_sizes(&socket, 65536, 65536)?;
|
// Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
|
||||||
|
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set socket buffer sizes
|
|
||||||
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
|
|
||||||
// These may fail on some systems, so we ignore errors
|
|
||||||
let _ = socket.set_recv_buffer_size(recv);
|
|
||||||
let _ = socket.set_send_buffer_size(send);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Configure socket for accepting client connections
|
/// Configure socket for accepting client connections
|
||||||
pub fn configure_client_socket(
|
pub fn configure_client_socket(
|
||||||
stream: &TcpStream,
|
stream: &TcpStream,
|
||||||
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
|
|||||||
socket.set_tcp_keepalive(&keepalive)?;
|
socket.set_tcp_keepalive(&keepalive)?;
|
||||||
|
|
||||||
// Set TCP user timeout (Linux only)
|
// Set TCP user timeout (Linux only)
|
||||||
|
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
|
||||||
|
// is implemented in relay_bidirectional instead
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
{
|
{
|
||||||
use std::os::unix::io::AsRawFd;
|
use std::os::unix::io::AsRawFd;
|
||||||
|
|||||||
Reference in New Issue
Block a user