diff --git a/src/config/mod.rs b/src/config/mod.rs
index 07ca324..ef3fa35 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -1,7 +1,7 @@
 //! Configuration
 
 use std::collections::HashMap;
-use std::net::IpAddr;
+use std::net::{IpAddr, SocketAddr};
 use std::path::Path;
 use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
@@ -336,6 +336,21 @@ pub struct ProxyConfig {
     #[serde(default)]
     pub show_link: Vec<String>,
+
+    /// DC address overrides for non-standard DCs (CDN, media, test, etc.)
+    /// Keys are DC indices as strings, values are "ip:port" addresses.
+    /// Matches the C implementation's `proxy_for <dc> <ip>:<port>;` config directive.
+    /// Example in config.toml:
+    ///   [dc_overrides]
+    ///   "203" = "149.154.175.100:443"
+    #[serde(default)]
+    pub dc_overrides: HashMap<String, String>,
+
+    /// Default DC index (1-5) for unmapped non-standard DCs.
+    /// Matches the C implementation's `default <dc>;` config directive.
+    /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
+    #[serde(default)]
+    pub default_dc: Option<u16>,
 }
 
 impl ProxyConfig {
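For reference, the two new fields accept TOML like the following. A minimal round-trip sketch, assuming the `toml` and `serde` crates that `src/config/mod.rs` already uses; `DcRouting` is a stand-in struct for illustration, not the real `ProxyConfig`:

```rust
use std::collections::HashMap;
use serde::Deserialize;

// Stand-in for the two new ProxyConfig fields (hypothetical struct).
#[derive(Debug, Deserialize)]
struct DcRouting {
    #[serde(default)]
    dc_overrides: HashMap<String, String>,
    #[serde(default)]
    default_dc: Option<u16>,
}

fn main() {
    // Keys are strings, so negative (media) DC indices are representable too.
    let cfg: DcRouting = toml::from_str(
        r#"
            default_dc = 2

            [dc_overrides]
            "203" = "149.154.175.100:443"
            "-203" = "149.154.175.100:443"
        "#,
    )
    .expect("valid TOML");

    assert_eq!(cfg.default_dc, Some(2));
    assert_eq!(cfg.dc_overrides["203"], "149.154.175.100:443");
}
```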
diff --git a/src/main.rs b/src/main.rs
index 5fc1502..d91bf2b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,6 +5,7 @@
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::net::TcpListener;
 use tokio::signal;
+use tokio::sync::Semaphore;
 use tracing::{info, error, warn, debug};
 use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*};
@@ -151,6 +152,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
     let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
 
+    // Connection concurrency limit — prevents OOM under SYN flood / connection storm.
+    // 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead),
+    // so 10000 connections ≈ 640MB peak memory.
+    let max_connections = Arc::new(Semaphore::new(10_000));
+
     // Startup DC ping
     info!("=== Telegram DC Connectivity ===");
     let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
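The hunk above only creates the semaphore; the accept loop that consumes it is outside this diff. A sketch of how the permit would typically gate each spawned connection (`accept_loop` and the handler call are placeholders, not the crate's actual names):

```rust
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Semaphore;

async fn accept_loop(listener: TcpListener, max_connections: Arc<Semaphore>) {
    loop {
        // Wait for a free slot *before* accepting, so a connection storm
        // backs up in the kernel queue instead of exhausting memory.
        let permit = max_connections.clone().acquire_owned().await.unwrap();
        let (socket, peer) = match listener.accept().await {
            Ok(pair) => pair,
            Err(_) => continue,
        };
        tokio::spawn(async move {
            let _permit = permit; // released when the task finishes
            // handle_connection(socket, peer).await;  // placeholder handler
            let _ = (socket, peer);
        });
    }
}
```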
diff --git a/src/proxy/client.rs b/src/proxy/client.rs
index 107cb7b..adcb25b 100644
--- a/src/proxy/client.rs
+++ b/src/proxy/client.rs
@@ -298,20 +298,70 @@ impl RunningClientHandler {
         Ok(())
     }
 
+    /// Resolve a DC index to a target address.
+    ///
+    /// Matches the C implementation's behavior exactly:
+    ///
+    /// 1. Look the DC up in the known clusters (standard DCs ±1..±5)
+    /// 2. If not found and `force=1` → fall back to `default_cluster`
+    ///
+    /// In the C code:
+    /// - `proxy-multi.conf` is downloaded from Telegram and contains only DCs ±1..±5
+    /// - the `default 2;` directive sets the default cluster
+    /// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns `default_cluster`
+    ///   for any unknown DC (like CDN DC 203)
+    ///
+    /// So DC 203, DC 101, DC -300, etc. all route to the default DC (2).
+    /// There is NO modular arithmetic in the C implementation.
     fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
-        let idx = (dc_idx.abs() - 1) as usize;
         let datacenters = if config.general.prefer_ipv6 {
             &*TG_DATACENTERS_V6
         } else {
            &*TG_DATACENTERS_V4
         };
 
-        datacenters.get(idx)
-            .map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
-            .ok_or_else(|| ProxyError::InvalidHandshake(
-                format!("Invalid DC index: {}", dc_idx)
-            ))
+        let num_dcs = datacenters.len(); // 5
+
+        // === Step 1: Check dc_overrides (like C's `proxy_for <dc> <ip>:<port>;`) ===
+        let dc_key = dc_idx.to_string();
+        if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
+            match addr_str.parse::<SocketAddr>() {
+                Ok(addr) => {
+                    debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
+                    return Ok(addr);
+                }
+                Err(_) => {
+                    warn!(dc_idx = dc_idx, addr_str = %addr_str,
+                          "Invalid DC override address in config, ignoring");
+                }
+            }
+        }
+
+        // === Step 2: Standard DCs ±1..±5 — direct lookup ===
+        let abs_dc = dc_idx.unsigned_abs() as usize;
+        if abs_dc >= 1 && abs_dc <= num_dcs {
+            return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
+        }
+
+        // === Step 3: Unknown DC — fall back to default_cluster ===
+        // Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)`,
+        // which returns `MC->default_cluster` when the DC is not found.
+        // Telegram's proxy-multi.conf uses `default 2;`.
+        let default_dc = config.default_dc.unwrap_or(2) as usize;
+        let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
+            default_dc - 1
+        } else {
+            1 // DC 2 (index 1) — matches Telegram's `default 2;`
+        };
+
+        info!(
+            original_dc = dc_idx,
+            fallback_dc = (fallback_idx + 1) as u16,
+            fallback_addr = %datacenters[fallback_idx],
+            "Special DC ---> default_cluster"
+        );
+
+        Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT))
     }
 
     async fn do_tg_handshake_static(
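The three-step resolution order is easiest to verify in isolation. A self-contained sketch with the DC list and override map passed in directly instead of read from `ProxyConfig` and `TG_DATACENTERS_*`; `resolve_dc` is illustrative, not the function above:

```rust
use std::collections::HashMap;
use std::net::SocketAddr;

// Mirrors the three-step lookup: override, then standard DC, then default.
fn resolve_dc(
    dc_idx: i16,
    overrides: &HashMap<String, String>,
    datacenters: &[SocketAddr],
    default_dc: usize,
) -> SocketAddr {
    // Step 1: an explicit override wins
    if let Some(addr) = overrides
        .get(&dc_idx.to_string())
        .and_then(|s| s.parse().ok())
    {
        return addr;
    }
    // Step 2: standard DCs ±1..±5
    let abs_dc = dc_idx.unsigned_abs() as usize;
    if (1..=datacenters.len()).contains(&abs_dc) {
        return datacenters[abs_dc - 1];
    }
    // Step 3: everything else routes to the default cluster
    datacenters[default_dc.clamp(1, datacenters.len()) - 1]
}

fn main() {
    let dcs: Vec<SocketAddr> = (1..=5)
        .map(|i| format!("10.0.0.{i}:443").parse().unwrap())
        .collect();
    let overrides = HashMap::from([("203".to_string(), "10.9.9.9:443".to_string())]);

    assert_eq!(resolve_dc(203, &overrides, &dcs, 2), "10.9.9.9:443".parse().unwrap());
    assert_eq!(resolve_dc(-4, &overrides, &dcs, 2), dcs[3]); // standard DC ±4
    assert_eq!(resolve_dc(101, &overrides, &dcs, 2), dcs[1]); // unknown -> default DC 2
}
```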
diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs
index 27bb867..e81804a 100644
--- a/src/proxy/masking.rs
+++ b/src/proxy/masking.rs
@@ -9,6 +9,9 @@
 use tracing::debug;
 use crate::config::ProxyConfig;
 
 const MASK_TIMEOUT: Duration = Duration::from_secs(5);
+/// Maximum duration for the entire masking relay.
+/// Limits resource consumption from slow-loris attacks and port scanners.
+const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
 const MASK_BUFFER_SIZE: usize = 8192;
 
 /// Detect client type based on initial data
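The call site that applies `MASK_RELAY_TIMEOUT` is not part of this diff. Presumably it wraps the whole masking relay future; a sketch of that shape, with `run_masked` as a hypothetical helper:

```rust
use std::time::Duration;
use tokio::time::timeout;

const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);

// Hypothetical wrapper: cap the whole masking relay so a port scanner or
// slow-loris peer can occupy a connection for at most 60 seconds.
async fn run_masked<F>(relay: F) -> std::io::Result<()>
where
    F: std::future::Future<Output = std::io::Result<()>>,
{
    match timeout(MASK_RELAY_TIMEOUT, relay).await {
        Ok(result) => result,    // relay finished (or failed) on its own
        Err(_elapsed) => Ok(()), // deadline hit: just drop the connection
    }
}
```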
diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs
index 4246e18..cd85304 100644
--- a/src/proxy/relay.rs
+++ b/src/proxy/relay.rs
@@ -1,262 +1,195 @@
 //! Bidirectional Relay
 
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
 use tokio::time::Instant;
-use tracing::{debug, trace, warn, info};
+use tracing::{debug, trace, warn};
 use crate::error::Result;
 use crate::stats::Stats;
 use crate::stream::BufferPool;
-use std::sync::atomic::{AtomicU64, Ordering};
 
 // Activity timeout for iOS compatibility (30 minutes)
 const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
 
-/// Relay data bidirectionally between client and server
+/// Relay data bidirectionally between client and server.
+///
+/// Uses a single-task select!-based loop instead of spawning two tasks.
+/// This eliminates:
+/// - 2x task spawn overhead per connection
+/// - the zombie-task problem (the old code used select! on JoinHandles but
+///   never aborted the losing task — it could keep running for up to 30 min)
+/// - extra Arc allocations for cross-task byte counters
+///
+/// The flush()-per-write was also removed: TCP_NODELAY is set on all
+/// sockets (socket.rs), so data is pushed immediately without Nagle
+/// buffering. An explicit flush() on every small read was causing a
+/// syscall storm and defeating CryptoWriter's internal coalescing.
 pub async fn relay_bidirectional<CR, CW, SR, SW>(
     mut client_reader: CR,
     mut client_writer: CW,
     mut server_reader: SR,
     mut server_writer: SW,
     user: &str,
     stats: Arc<Stats>,
     buffer_pool: Arc<BufferPool>,
 ) -> Result<()>
 where
     CR: AsyncRead + Unpin + Send + 'static,
     CW: AsyncWrite + Unpin + Send + 'static,
     SR: AsyncRead + Unpin + Send + 'static,
     SW: AsyncWrite + Unpin + Send + 'static,
 {
-    let user_c2s = user.to_string();
-    let user_s2c = user.to_string();
-    let stats_c2s = Arc::clone(&stats);
-    let stats_s2c = Arc::clone(&stats);
-
-    let c2s_bytes = Arc::new(AtomicU64::new(0));
-    let s2c_bytes = Arc::new(AtomicU64::new(0));
-    let c2s_bytes_clone = Arc::clone(&c2s_bytes);
-    let s2c_bytes_clone = Arc::clone(&s2c_bytes);
-
-    let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
-
-    let pool_c2s = buffer_pool.clone();
-    let pool_s2c = buffer_pool.clone();
-
-    // Client -> Server task
-    let c2s = tokio::spawn(async move {
-        // Get buffer from pool
-        let mut pooled_buf = pool_c2s.get();
-        // CRITICAL FIX: BytesMut from pool has len 0. We must resize it to be
-        // usable as &mut [u8]. We use the full capacity.
-        let cap = pooled_buf.capacity();
-        pooled_buf.resize(cap, 0);
-
-        let mut total_bytes = 0u64;
-        let mut prev_total_bytes = 0u64;
-        let mut msg_count = 0u64;
-        let mut last_activity = Instant::now();
-        let mut last_log = Instant::now();
-
-        loop {
-            // Read with timeout
-            let read_result = tokio::time::timeout(
-                activity_timeout,
-                client_reader.read(&mut pooled_buf)
-            ).await;
-
-            match read_result {
-                Err(_) => {
-                    warn!(
-                        user = %user_c2s,
-                        total_bytes = total_bytes,
-                        msgs = msg_count,
-                        idle_secs = last_activity.elapsed().as_secs(),
-                        "Activity timeout (C->S) - no data received"
-                    );
-                    let _ = server_writer.shutdown().await;
-                    break;
-                }
-
-                Ok(Ok(0)) => {
-                    debug!(
-                        user = %user_c2s,
-                        total_bytes = total_bytes,
-                        msgs = msg_count,
-                        "Client closed connection (C->S)"
-                    );
-                    let _ = server_writer.shutdown().await;
-                    break;
-                }
-
-                Ok(Ok(n)) => {
-                    total_bytes += n as u64;
-                    msg_count += 1;
-                    last_activity = Instant::now();
-                    c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
-
-                    stats_c2s.add_user_octets_from(&user_c2s, n as u64);
-                    stats_c2s.increment_user_msgs_from(&user_c2s);
-
-                    trace!(
-                        user = %user_c2s,
-                        bytes = n,
-                        total = total_bytes,
-                        "C->S data"
-                    );
-
-                    // Log activity every 10 seconds with correct rate
-                    let elapsed = last_log.elapsed();
-                    if elapsed > Duration::from_secs(10) {
-                        let delta = total_bytes - prev_total_bytes;
-                        let rate = delta as f64 / elapsed.as_secs_f64();
-
-                        debug!(
-                            user = %user_c2s,
-                            total_bytes = total_bytes,
-                            msgs = msg_count,
-                            rate_kbps = (rate / 1024.0) as u64,
-                            "C->S transfer in progress"
-                        );
-
-                        last_log = Instant::now();
-                        prev_total_bytes = total_bytes;
-                    }
-
-                    if let Err(e) = server_writer.write_all(&pooled_buf[..n]).await {
-                        debug!(user = %user_c2s, error = %e, "Failed to write to server");
-                        break;
-                    }
-                    if let Err(e) = server_writer.flush().await {
-                        debug!(user = %user_c2s, error = %e, "Failed to flush to server");
-                        break;
-                    }
-                }
-
-                Ok(Err(e)) => {
-                    debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
-                    break;
-                }
-            }
-        }
-    });
-
-    // Server -> Client task
-    let s2c = tokio::spawn(async move {
-        // Get buffer from pool
-        let mut pooled_buf = pool_s2c.get();
-        // CRITICAL FIX: Resize buffer
-        let cap = pooled_buf.capacity();
-        pooled_buf.resize(cap, 0);
-
-        let mut total_bytes = 0u64;
-        let mut prev_total_bytes = 0u64;
-        let mut msg_count = 0u64;
-        let mut last_activity = Instant::now();
-        let mut last_log = Instant::now();
-
-        loop {
-            let read_result = tokio::time::timeout(
-                activity_timeout,
-                server_reader.read(&mut pooled_buf)
-            ).await;
-
-            match read_result {
-                Err(_) => {
-                    warn!(
-                        user = %user_s2c,
-                        total_bytes = total_bytes,
-                        msgs = msg_count,
-                        idle_secs = last_activity.elapsed().as_secs(),
-                        "Activity timeout (S->C) - no data received"
-                    );
-                    let _ = client_writer.shutdown().await;
-                    break;
-                }
-
-                Ok(Ok(0)) => {
-                    debug!(
-                        user = %user_s2c,
-                        total_bytes = total_bytes,
-                        msgs = msg_count,
-                        "Server closed connection (S->C)"
-                    );
-                    let _ = client_writer.shutdown().await;
-                    break;
-                }
-
-                Ok(Ok(n)) => {
-                    total_bytes += n as u64;
-                    msg_count += 1;
-                    last_activity = Instant::now();
-                    s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
-
-                    stats_s2c.add_user_octets_to(&user_s2c, n as u64);
-                    stats_s2c.increment_user_msgs_to(&user_s2c);
-
-                    trace!(
-                        user = %user_s2c,
-                        bytes = n,
-                        total = total_bytes,
-                        "S->C data"
-                    );
-
-                    let elapsed = last_log.elapsed();
-                    if elapsed > Duration::from_secs(10) {
-                        let delta = total_bytes - prev_total_bytes;
-                        let rate = delta as f64 / elapsed.as_secs_f64();
-
-                        debug!(
-                            user = %user_s2c,
-                            total_bytes = total_bytes,
-                            msgs = msg_count,
-                            rate_kbps = (rate / 1024.0) as u64,
-                            "S->C transfer in progress"
-                        );
-
-                        last_log = Instant::now();
-                        prev_total_bytes = total_bytes;
-                    }
-
-                    if let Err(e) = client_writer.write_all(&pooled_buf[..n]).await {
-                        debug!(user = %user_s2c, error = %e, "Failed to write to client");
-                        break;
-                    }
-                    if let Err(e) = client_writer.flush().await {
-                        debug!(user = %user_s2c, error = %e, "Failed to flush to client");
-                        break;
-                    }
-                }
-
-                Ok(Err(e)) => {
-                    debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
-                    break;
-                }
-            }
-        }
-    });
-
-    // Wait for either direction to complete
-    tokio::select! {
-        result = c2s => {
-            if let Err(e) = result {
-                warn!(error = %e, "C->S task panicked");
-            }
-        }
-        result = s2c => {
-            if let Err(e) = result {
-                warn!(error = %e, "S->C task panicked");
-            }
-        }
-    }
-
-    debug!(
-        c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
-        s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
-        "Relay finished"
-    );
-
-    Ok(())
-}
\ No newline at end of file
+    // Get buffers from the pool — one per direction.
+    // BytesMut from the pool has len 0; resize to full capacity so it can
+    // be used as &mut [u8].
+    let mut c2s_buf = buffer_pool.get();
+    let cap = c2s_buf.capacity();
+    c2s_buf.resize(cap, 0);
+
+    let mut s2c_buf = buffer_pool.get();
+    let cap = s2c_buf.capacity();
+    s2c_buf.resize(cap, 0);
+
+    let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
+
+    let mut c2s_total: u64 = 0;
+    let mut s2c_total: u64 = 0;
+    let mut c2s_msgs: u64 = 0;
+    let mut s2c_msgs: u64 = 0;
+
+    // For periodic rate logging
+    let mut c2s_prev: u64 = 0;
+    let mut s2c_prev: u64 = 0;
+    let mut last_log = Instant::now();
+
+    let user_owned = user.to_string();
+
+    loop {
+        tokio::select! {
+            biased;
+
+            // Client -> Server direction
+            result = tokio::time::timeout(activity_timeout, client_reader.read(&mut c2s_buf)) => {
+                match result {
+                    Err(_) => {
+                        // Activity timeout
+                        warn!(
+                            user = %user_owned,
+                            c2s_bytes = c2s_total,
+                            s2c_bytes = s2c_total,
+                            "Activity timeout (C->S)"
+                        );
+                        break;
+                    }
+                    Ok(Ok(0)) => {
+                        // Client closed
+                        debug!(
+                            user = %user_owned,
+                            c2s_bytes = c2s_total,
+                            s2c_bytes = s2c_total,
+                            "Client closed connection"
+                        );
+                        break;
+                    }
+                    Ok(Ok(n)) => {
+                        c2s_total += n as u64;
+                        c2s_msgs += 1;
+
+                        stats.add_user_octets_from(&user_owned, n as u64);
+                        stats.increment_user_msgs_from(&user_owned);
+
+                        trace!(user = %user_owned, bytes = n, "C->S");
+
+                        // Write without flush — TCP_NODELAY handles the push
+                        if let Err(e) = server_writer.write_all(&c2s_buf[..n]).await {
+                            debug!(user = %user_owned, error = %e, "Write to server failed");
+                            break;
+                        }
+                    }
+                    Ok(Err(e)) => {
+                        debug!(user = %user_owned, error = %e, "Client read error");
+                        break;
+                    }
+                }
+            }
+
+            // Server -> Client direction
+            result = tokio::time::timeout(activity_timeout, server_reader.read(&mut s2c_buf)) => {
+                match result {
+                    Err(_) => {
+                        warn!(
+                            user = %user_owned,
+                            c2s_bytes = c2s_total,
+                            s2c_bytes = s2c_total,
+                            "Activity timeout (S->C)"
+                        );
+                        break;
+                    }
+                    Ok(Ok(0)) => {
+                        debug!(
+                            user = %user_owned,
+                            c2s_bytes = c2s_total,
+                            s2c_bytes = s2c_total,
+                            "Server closed connection"
+                        );
+                        break;
+                    }
+                    Ok(Ok(n)) => {
+                        s2c_total += n as u64;
+                        s2c_msgs += 1;
+
+                        stats.add_user_octets_to(&user_owned, n as u64);
+                        stats.increment_user_msgs_to(&user_owned);
+
+                        trace!(user = %user_owned, bytes = n, "S->C");
+
+                        if let Err(e) = client_writer.write_all(&s2c_buf[..n]).await {
+                            debug!(user = %user_owned, error = %e, "Write to client failed");
+                            break;
+                        }
+                    }
+                    Ok(Err(e)) => {
+                        debug!(user = %user_owned, error = %e, "Server read error");
+                        break;
+                    }
+                }
+            }
+        }
+
+        // Periodic rate logging (every 10s)
+        let elapsed = last_log.elapsed();
+        if elapsed > Duration::from_secs(10) {
+            let secs = elapsed.as_secs_f64();
+            let c2s_delta = c2s_total - c2s_prev;
+            let s2c_delta = s2c_total - s2c_prev;
+
+            debug!(
+                user = %user_owned,
+                c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
+                s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
+                c2s_total = c2s_total,
+                s2c_total = s2c_total,
+                "Relay active"
+            );
+
+            c2s_prev = c2s_total;
+            s2c_prev = s2c_total;
+            last_log = Instant::now();
+        }
+    }
+
+    // Clean shutdown of both directions
+    let _ = server_writer.shutdown().await;
+    let _ = client_writer.shutdown().await;
+
+    debug!(
+        user = %user_owned,
+        c2s_bytes = c2s_total,
+        s2c_bytes = s2c_total,
+        c2s_msgs = c2s_msgs,
+        s2c_msgs = s2c_msgs,
+        "Relay finished"
+    );
+
+    Ok(())
+}
\ No newline at end of file
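The no-flush design above leans on TCP_NODELAY being set in `socket.rs`, which is not shown in this diff. For completeness, a sketch of that invariant with plain tokio streams:

```rust
use tokio::net::TcpStream;

// The invariant socket.rs is said to establish: with Nagle's algorithm
// disabled, each write_all() in the relay is pushed to the wire right away,
// so no per-read flush() is needed.
fn prepare_socket(stream: &TcpStream) -> std::io::Result<()> {
    stream.set_nodelay(true)
}
```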
diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs
index 86e6b2a..6adf452 100644
--- a/src/transport/upstream.rs
+++ b/src/transport/upstream.rs
@@ -66,10 +66,25 @@ impl UpstreamState {
         }
     }
 
-    /// Convert dc_idx (1-based, may be negative) to array index 0..4
+    /// Map a DC index to a latency array slot (0..NUM_DCS).
+    ///
+    /// Matches the C implementation's `mf_cluster_lookup` behavior:
+    /// - Standard DCs ±1..±5 → direct mapping to array indices 0..4
+    /// - Unknown DCs (CDN, media, etc.) → the default DC slot (index 1 = DC 2),
+    ///   matching Telegram's `default 2;` in proxy-multi.conf
+    /// - There is NO modular arithmetic in the C implementation.
     fn dc_array_idx(dc_idx: i16) -> Option<usize> {
-        let idx = (dc_idx.unsigned_abs() as usize).checked_sub(1)?;
-        if idx < NUM_DCS { Some(idx) } else { None }
+        let abs_dc = dc_idx.unsigned_abs() as usize;
+        if abs_dc == 0 {
+            return None;
+        }
+        if abs_dc <= NUM_DCS {
+            Some(abs_dc - 1)
+        } else {
+            // Unknown DC → default cluster (DC 2, index 1).
+            // Same as C: mf_cluster_lookup returns default_cluster.
+            Some(1)
+        }
+    }
 
     /// Get latency for a specific DC, falling back to average across all known DCs
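A quick standalone check of the new mapping, mirroring `dc_array_idx` above with `NUM_DCS = 5` as stated in the comments:

```rust
const NUM_DCS: usize = 5;

fn dc_array_idx(dc_idx: i16) -> Option<usize> {
    let abs_dc = dc_idx.unsigned_abs() as usize;
    if abs_dc == 0 {
        return None;
    }
    if abs_dc <= NUM_DCS {
        Some(abs_dc - 1)
    } else {
        Some(1) // unknown DC -> default cluster slot (DC 2)
    }
}

fn main() {
    assert_eq!(dc_array_idx(1), Some(0));
    assert_eq!(dc_array_idx(-5), Some(4));  // media DCs share their base DC's slot
    assert_eq!(dc_array_idx(203), Some(1)); // CDN DC falls back to DC 2's latency slot
    assert_eq!(dc_array_idx(0), None);
}
```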