diff --git a/src/main.rs b/src/main.rs index 9f0b167..860df54 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,11 +26,6 @@ use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::util::ip::detect_ip; use crate::stream::BufferPool; -/// Parse command-line arguments. -/// -/// Usage: telemt [config_path] [--silent] [--log-level ] -/// -/// Returns (config_path, silent_flag, log_level_override) fn parse_cli() -> (String, bool, Option) { let mut config_path = "config.toml".to_string(); let mut silent = false; @@ -40,33 +35,23 @@ fn parse_cli() -> (String, bool, Option) { let mut i = 0; while i < args.len() { match args[i].as_str() { - "--silent" | "-s" => { - silent = true; - } + "--silent" | "-s" => { silent = true; } "--log-level" => { i += 1; - if i < args.len() { - log_level = Some(args[i].clone()); - } + if i < args.len() { log_level = Some(args[i].clone()); } } s if s.starts_with("--log-level=") => { log_level = Some(s.trim_start_matches("--log-level=").to_string()); } "--help" | "-h" => { eprintln!("Usage: telemt [config.toml] [OPTIONS]"); - eprintln!(); - eprintln!("Options:"); - eprintln!(" --silent, -s Suppress info logs (only warn/error)"); - eprintln!(" --log-level Set log level: debug|verbose|normal|silent"); + eprintln!(" --silent, -s Suppress info logs"); + eprintln!(" --log-level debug|verbose|normal|silent"); eprintln!(" --help, -h Show this help"); std::process::exit(0); } - s if !s.starts_with('-') => { - config_path = s.to_string(); - } - other => { - eprintln!("Unknown option: {}", other); - } + s if !s.starts_with('-') => { config_path = s.to_string(); } + other => { eprintln!("Unknown option: {}", other); } } i += 1; } @@ -76,20 +61,17 @@ fn parse_cli() -> (String, bool, Option) { #[tokio::main] async fn main() -> Result<(), Box> { - // 1. Parse CLI arguments let (config_path, cli_silent, cli_log_level) = parse_cli(); - // 2. Load config (tracing not yet initialized — errors go to stderr) let config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { if std::path::Path::new(&config_path).exists() { - eprintln!("[telemt] Error: Failed to load config '{}': {}", config_path, e); + eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); - let toml_str = toml::to_string_pretty(&default).unwrap(); - std::fs::write(&config_path, toml_str).unwrap(); + std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); eprintln!("[telemt] Created default config at {}", config_path); default } @@ -97,80 +79,90 @@ async fn main() -> Result<(), Box> { }; if let Err(e) = config.validate() { - eprintln!("[telemt] Error: Invalid configuration: {}", e); + eprintln!("[telemt] Invalid config: {}", e); std::process::exit(1); } - // 3. Determine effective log level - // Priority: RUST_LOG env > CLI flags > config file > default (normal) let effective_log_level = if cli_silent { LogLevel::Silent - } else if let Some(ref level_str) = cli_log_level { - LogLevel::from_str_loose(level_str) + } else if let Some(ref s) = cli_log_level { + LogLevel::from_str_loose(s) } else { config.general.log_level.clone() }; - // 4. Initialize tracing let filter = if std::env::var("RUST_LOG").is_ok() { - // RUST_LOG takes absolute priority EnvFilter::from_default_env() } else { EnvFilter::new(effective_log_level.to_filter_str()) }; - fmt() - .with_env_filter(filter) - .init(); + fmt().with_env_filter(filter).init(); - // 5. Log startup info (operational — respects log level) info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Log level: {}", effective_log_level); - info!( - "Modes: classic={} secure={} tls={}", + info!("Modes: classic={} secure={} tls={}", config.general.modes.classic, config.general.modes.secure, - config.general.modes.tls - ); + config.general.modes.tls); info!("TLS domain: {}", config.censorship.tls_domain); - info!( - "Mask: {} -> {}:{}", + info!("Mask: {} -> {}:{}", config.censorship.mask, config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), - config.censorship.mask_port - ); + config.censorship.mask_port); if config.censorship.tls_domain == "www.google.com" { warn!("Using default tls_domain (www.google.com). Consider setting a custom domain."); } + let prefer_ipv6 = config.general.prefer_ipv6; let config = Arc::new(config); let stats = Arc::new(Stats::new()); let rng = Arc::new(SecureRandom::new()); - // Initialize ReplayChecker let replay_checker = Arc::new(ReplayChecker::new( config.access.replay_check_len, Duration::from_secs(config.access.replay_window_secs), )); - // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); - - // Initialize Buffer Pool (16KB buffers, max 4096 cached ≈ 64MB) let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); - // Start health checks + // === Startup DC Ping === + println!("=== Telegram DC Connectivity ==="); + let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; + for upstream_result in &ping_results { + println!(" via {}", upstream_result.upstream_name); + for dc in &upstream_result.results { + match (&dc.rtt_ms, &dc.error) { + (Some(rtt), _) => { + println!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt); + } + (None, Some(err)) => { + println!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err); + } + (None, None) => { + println!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr); + } + } + } + } + println!("================================"); + + // Start background tasks let um_clone = upstream_manager.clone(); tokio::spawn(async move { - um_clone.run_health_checks().await; + um_clone.run_health_checks(prefer_ipv6).await; + }); + + let rc_clone = replay_checker.clone(); + tokio::spawn(async move { + rc_clone.run_periodic_cleanup().await; }); - // Detect public IP (once at startup) let detected_ip = detect_ip().await; debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); - // 6. Start listeners let mut listeners = Vec::new(); for listener_conf in &config.server.listeners { @@ -185,7 +177,6 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - // Determine public IP for tg:// links let public_ip = if let Some(ip) = listener_conf.announce_ip { ip } else if listener_conf.ip.is_unspecified() { @@ -198,30 +189,26 @@ async fn main() -> Result<(), Box> { listener_conf.ip }; - // 7. Print proxy links (always visible — uses println!, not tracing) if !config.show_link.is_empty() { println!("--- Proxy Links ({}) ---", public_ip); for user_name in &config.show_link { if let Some(secret) = config.access.users.get(user_name) { println!("[{}]", user_name); - if config.general.modes.classic { - println!(" Classic: tg://proxy?server={}&port={}&secret={}", + println!(" Classic: tg://proxy?server={}&port={}&secret={}", public_ip, config.server.port, secret); } - if config.general.modes.secure { - println!(" DD: tg://proxy?server={}&port={}&secret=dd{}", + println!(" DD: tg://proxy?server={}&port={}&secret=dd{}", public_ip, config.server.port, secret); } - if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); - println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", public_ip, config.server.port, secret, domain_hex); } } else { - warn!("User '{}' in show_link not found in users", user_name); + warn!("User '{}' in show_link not found", user_name); } } println!("------------------------"); @@ -236,11 +223,10 @@ async fn main() -> Result<(), Box> { } if listeners.is_empty() { - error!("No listeners could be started. Exiting."); + error!("No listeners. Exiting."); std::process::exit(1); } - // 8. Accept loop for listener in listeners { let config = config.clone(); let stats = stats.clone(); @@ -262,14 +248,8 @@ async fn main() -> Result<(), Box> { tokio::spawn(async move { if let Err(e) = ClientHandler::new( - stream, - peer_addr, - config, - stats, - upstream_manager, - replay_checker, - buffer_pool, - rng + stream, peer_addr, config, stats, + upstream_manager, replay_checker, buffer_pool, rng ).run().await { debug!(peer = %peer_addr, error = %e, "Connection error"); } @@ -284,7 +264,6 @@ async fn main() -> Result<(), Box> { }); } - // 9. Wait for shutdown signal match signal::ctrl_c().await { Ok(()) => info!("Shutting down..."), Err(e) => error!("Signal error: {}", e), diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 39002e8..fb30742 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1,32 +1,28 @@ -//! Statistics +//! Statistics and replay protection use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::{Instant, Duration}; use dashmap::DashMap; -use parking_lot::{RwLock, Mutex}; +use parking_lot::Mutex; use lru::LruCache; use std::num::NonZeroUsize; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::collections::VecDeque; +use tracing::debug; + +// ============= Stats ============= -/// Thread-safe statistics #[derive(Default)] pub struct Stats { - // Global counters connects_all: AtomicU64, connects_bad: AtomicU64, handshake_timeouts: AtomicU64, - - // Per-user stats user_stats: DashMap, - - // Start time - start_time: RwLock>, + start_time: parking_lot::RwLock>, } -/// Per-user statistics #[derive(Default)] pub struct UserStats { pub connects: AtomicU64, @@ -44,42 +40,20 @@ impl Stats { stats } - // Global stats - pub fn increment_connects_all(&self) { - self.connects_all.fetch_add(1, Ordering::Relaxed); - } + pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); } + pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); } + pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } + pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } + pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } - pub fn increment_connects_bad(&self) { - self.connects_bad.fetch_add(1, Ordering::Relaxed); - } - - pub fn increment_handshake_timeouts(&self) { - self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); - } - - pub fn get_connects_all(&self) -> u64 { - self.connects_all.load(Ordering::Relaxed) - } - - pub fn get_connects_bad(&self) -> u64 { - self.connects_bad.load(Ordering::Relaxed) - } - - // User stats pub fn increment_user_connects(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .connects - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .connects.fetch_add(1, Ordering::Relaxed); } pub fn increment_user_curr_connects(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .curr_connects - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .curr_connects.fetch_add(1, Ordering::Relaxed); } pub fn decrement_user_curr_connects(&self, user: &str) { @@ -89,47 +63,33 @@ impl Stats { } pub fn get_user_curr_connects(&self, user: &str) -> u64 { - self.user_stats - .get(user) + self.user_stats.get(user) .map(|s| s.curr_connects.load(Ordering::Relaxed)) .unwrap_or(0) } pub fn add_user_octets_from(&self, user: &str, bytes: u64) { - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_from_client - .fetch_add(bytes, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .octets_from_client.fetch_add(bytes, Ordering::Relaxed); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_to_client - .fetch_add(bytes, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .octets_to_client.fetch_add(bytes, Ordering::Relaxed); } pub fn increment_user_msgs_from(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_from_client - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .msgs_from_client.fetch_add(1, Ordering::Relaxed); } pub fn increment_user_msgs_to(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_to_client - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .msgs_to_client.fetch_add(1, Ordering::Relaxed); } pub fn get_user_total_octets(&self, user: &str) -> u64 { - self.user_stats - .get(user) + self.user_stats.get(user) .map(|s| { s.octets_from_client.load(Ordering::Relaxed) + s.octets_to_client.load(Ordering::Relaxed) @@ -144,21 +104,27 @@ impl Stats { } } -/// Sharded Replay attack checker using LRU cache + sliding window -/// Uses multiple independent LRU caches to reduce lock contention +// ============= Replay Checker ============= + pub struct ReplayChecker { shards: Vec>, shard_mask: usize, window: Duration, + checks: AtomicU64, + hits: AtomicU64, + additions: AtomicU64, + cleanups: AtomicU64, } struct ReplayEntry { seen_at: Instant, + seq: u64, } struct ReplayShard { - cache: LruCache, ReplayEntry>, - queue: VecDeque<(Instant, Vec)>, + cache: LruCache, ReplayEntry>, + queue: VecDeque<(Instant, Box<[u8]>, u64)>, + seq_counter: u64, } impl ReplayShard { @@ -166,33 +132,60 @@ impl ReplayShard { Self { cache: LruCache::new(cap), queue: VecDeque::with_capacity(cap.get()), + seq_counter: 0, } } + + fn next_seq(&mut self) -> u64 { + self.seq_counter += 1; + self.seq_counter + } fn cleanup(&mut self, now: Instant, window: Duration) { if window.is_zero() { return; } - let cutoff = now - window; - while let Some((ts, _)) = self.queue.front() { + let cutoff = now.checked_sub(window).unwrap_or(now); + + while let Some((ts, _, _)) = self.queue.front() { if *ts >= cutoff { break; } - let (ts_old, key_old) = self.queue.pop_front().unwrap(); - if let Some(entry) = self.cache.get(&key_old) { - if entry.seen_at <= ts_old { - self.cache.pop(&key_old); + let (_, key, queue_seq) = self.queue.pop_front().unwrap(); + + // Use key.as_ref() to get &[u8] — avoids Borrow ambiguity + // between Borrow<[u8]> and Borrow> + if let Some(entry) = self.cache.peek(key.as_ref()) { + if entry.seq == queue_seq { + self.cache.pop(key.as_ref()); } } } } + + fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { + self.cleanup(now, window); + // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]> + self.cache.get(key).is_some() + } + + fn add(&mut self, key: &[u8], now: Instant, window: Duration) { + self.cleanup(now, window); + + let seq = self.next_seq(); + let boxed_key: Box<[u8]> = key.into(); + + self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); + self.queue.push_back((now, boxed_key, seq)); + } + + fn len(&self) -> usize { + self.cache.len() + } } impl ReplayChecker { - /// Create new replay checker with specified capacity per shard - /// Total capacity = capacity * num_shards pub fn new(total_capacity: usize, window: Duration) -> Self { - // Use 64 shards for good concurrency let num_shards = 64; let shard_capacity = (total_capacity / num_shards).max(1); let cap = NonZeroUsize::new(shard_capacity).unwrap(); @@ -206,50 +199,114 @@ impl ReplayChecker { shards, shard_mask: num_shards - 1, window, + checks: AtomicU64::new(0), + hits: AtomicU64::new(0), + additions: AtomicU64::new(0), + cleanups: AtomicU64::new(0), } } - fn get_shard(&self, key: &[u8]) -> usize { + fn get_shard_idx(&self, key: &[u8]) -> usize { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); (hasher.finish() as usize) & self.shard_mask } fn check(&self, data: &[u8]) -> bool { - let shard_idx = self.get_shard(data); - let mut shard = self.shards[shard_idx].lock(); - let now = Instant::now(); - shard.cleanup(now, self.window); - - let key = data.to_vec(); - shard.cache.get(&key).is_some() + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = self.shards[idx].lock(); + let found = shard.check(data, Instant::now(), self.window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found } fn add(&self, data: &[u8]) { - let shard_idx = self.get_shard(data); - let mut shard = self.shards[shard_idx].lock(); - let now = Instant::now(); - shard.cleanup(now, self.window); - - let key = data.to_vec(); - shard.cache.put(key.clone(), ReplayEntry { seen_at: now }); - shard.queue.push_back((now, key)); + self.additions.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = self.shards[idx].lock(); + shard.add(data, Instant::now(), self.window); } - pub fn check_handshake(&self, data: &[u8]) -> bool { - self.check(data) + pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) } + pub fn add_handshake(&self, data: &[u8]) { self.add(data) } + pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) } + pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) } + + pub fn stats(&self) -> ReplayStats { + let mut total_entries = 0; + let mut total_queue_len = 0; + for shard in &self.shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + + ReplayStats { + total_entries, + total_queue_len, + total_checks: self.checks.load(Ordering::Relaxed), + total_hits: self.hits.load(Ordering::Relaxed), + total_additions: self.additions.load(Ordering::Relaxed), + total_cleanups: self.cleanups.load(Ordering::Relaxed), + num_shards: self.shards.len(), + window_secs: self.window.as_secs(), + } } - - pub fn add_handshake(&self, data: &[u8]) { - self.add(data) + + pub async fn run_periodic_cleanup(&self) { + let interval = if self.window.as_secs() > 60 { + Duration::from_secs(30) + } else { + Duration::from_secs(self.window.as_secs().max(1) / 2) + }; + + loop { + tokio::time::sleep(interval).await; + + let now = Instant::now(); + let mut cleaned = 0usize; + + for shard_mutex in &self.shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } + + self.cleanups.fetch_add(1, Ordering::Relaxed); + + if cleaned > 0 { + debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); + } + } } +} - pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check(data) +#[derive(Debug, Clone)] +pub struct ReplayStats { + pub total_entries: usize, + pub total_queue_len: usize, + pub total_checks: u64, + pub total_hits: u64, + pub total_additions: u64, + pub total_cleanups: u64, + pub num_shards: usize, + pub window_secs: u64, +} + +impl ReplayStats { + pub fn hit_rate(&self) -> f64 { + if self.total_checks == 0 { 0.0 } + else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 } } - - pub fn add_tls_digest(&self, data: &[u8]) { - self.add(data) + + pub fn ghost_ratio(&self) -> f64 { + if self.total_entries == 0 { 0.0 } + else { self.total_queue_len as f64 / self.total_entries as f64 } } } @@ -260,28 +317,60 @@ mod tests { #[test] fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); - - let stats1 = Arc::clone(&stats); - let stats2 = Arc::clone(&stats); - - stats1.increment_connects_all(); - stats2.increment_connects_all(); - stats1.increment_connects_all(); - + stats.increment_connects_all(); + stats.increment_connects_all(); + stats.increment_connects_all(); assert_eq!(stats.get_connects_all(), 3); } #[test] - fn test_replay_checker_sharding() { + fn test_replay_checker_basic() { let checker = ReplayChecker::new(100, Duration::from_secs(60)); - let data1 = b"test1"; - let data2 = b"test2"; - - checker.add_handshake(data1); - assert!(checker.check_handshake(data1)); - assert!(!checker.check_handshake(data2)); - - checker.add_handshake(data2); - assert!(checker.check_handshake(data2)); + assert!(!checker.check_handshake(b"test1")); + checker.add_handshake(b"test1"); + assert!(checker.check_handshake(b"test1")); + assert!(!checker.check_handshake(b"test2")); + } + + #[test] + fn test_replay_checker_duplicate_add() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + checker.add_handshake(b"dup"); + checker.add_handshake(b"dup"); + assert!(checker.check_handshake(b"dup")); + } + + #[test] + fn test_replay_checker_expiration() { + let checker = ReplayChecker::new(100, Duration::from_millis(50)); + checker.add_handshake(b"expire"); + assert!(checker.check_handshake(b"expire")); + std::thread::sleep(Duration::from_millis(100)); + assert!(!checker.check_handshake(b"expire")); + } + + #[test] + fn test_replay_checker_stats() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + checker.add_handshake(b"k1"); + checker.add_handshake(b"k2"); + checker.check_handshake(b"k1"); + checker.check_handshake(b"k3"); + let stats = checker.stats(); + assert_eq!(stats.total_additions, 2); + assert_eq!(stats.total_checks, 2); + assert_eq!(stats.total_hits, 1); + } + + #[test] + fn test_replay_checker_many_keys() { + let checker = ReplayChecker::new(1000, Duration::from_secs(60)); + for i in 0..500u32 { + checker.add(&i.to_le_bytes()); + } + for i in 0..500u32 { + assert!(checker.check(&i.to_le_bytes())); + } + assert_eq!(checker.stats().total_entries, 500); } } \ No newline at end of file diff --git a/src/transport/mod.rs b/src/transport/mod.rs index bbc5302..2b507d5 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -10,4 +10,4 @@ pub use pool::ConnectionPool; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use socket::*; pub use socks::*; -pub use upstream::UpstreamManager; \ No newline at end of file +pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; \ No newline at end of file diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 6c84011..242f599 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -1,26 +1,78 @@ -//! Upstream Management +//! Upstream Management with RTT tracking and startup ping use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::RwLock; +use tokio::time::Instant; use rand::Rng; use tracing::{debug, warn, error, info}; use crate::config::{UpstreamConfig, UpstreamType}; use crate::error::{Result, ProxyError}; +use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT}; use crate::transport::socket::create_outgoing_socket_bound; use crate::transport::socks::{connect_socks4, connect_socks5}; +// ============= RTT Tracking ============= + +/// Exponential moving average for latency tracking +#[derive(Debug, Clone)] +struct LatencyEma { + /// Current EMA value in milliseconds (None = no data yet) + value_ms: Option, + /// Smoothing factor (0.0 - 1.0, higher = more weight to recent) + alpha: f64, +} + +impl LatencyEma { + fn new(alpha: f64) -> Self { + Self { value_ms: None, alpha } + } + + fn update(&mut self, sample_ms: f64) { + self.value_ms = Some(match self.value_ms { + None => sample_ms, + Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha, + }); + } + + fn get(&self) -> Option { + self.value_ms + } +} + +// ============= Upstream State ============= + #[derive(Debug)] struct UpstreamState { config: UpstreamConfig, healthy: bool, fails: u32, last_check: std::time::Instant, + /// Latency EMA (alpha=0.3 — moderate smoothing) + latency: LatencyEma, } +/// Result of a single DC ping +#[derive(Debug, Clone)] +pub struct DcPingResult { + pub dc_idx: usize, + pub dc_addr: SocketAddr, + pub rtt_ms: Option, + pub error: Option, +} + +/// Result of startup ping across all DCs +#[derive(Debug, Clone)] +pub struct StartupPingResult { + pub results: Vec, + pub upstream_name: String, +} + +// ============= Upstream Manager ============= + #[derive(Clone)] pub struct UpstreamManager { upstreams: Arc>>, @@ -35,6 +87,7 @@ impl UpstreamManager { healthy: true, fails: 0, last_check: std::time::Instant::now(), + latency: LatencyEma::new(0.3), }) .collect(); @@ -43,7 +96,7 @@ impl UpstreamManager { } } - /// Select an upstream using Weighted Round Robin (simplified) + /// Select an upstream using weighted selection among healthy upstreams async fn select_upstream(&self) -> Option { let upstreams = self.upstreams.read().await; if upstreams.is_empty() { @@ -57,11 +110,9 @@ impl UpstreamManager { .collect(); if healthy_indices.is_empty() { - // If all unhealthy, try any random one return Some(rand::rng().gen_range(0..upstreams.len())); } - // Weighted selection let total_weight: u32 = healthy_indices.iter() .map(|&i| upstreams[i].config.weight as u32) .sum(); @@ -92,15 +143,19 @@ impl UpstreamManager { guard[idx].config.clone() }; + let start = Instant::now(); + match self.connect_via_upstream(&upstream, target).await { Ok(stream) => { + let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { if !u.healthy { - debug!("Upstream recovered: {:?}", u.config); + debug!(rtt_ms = rtt_ms, "Upstream recovered: {:?}", u.config); } u.healthy = true; u.fails = 0; + u.latency.update(rtt_ms); } Ok(stream) }, @@ -108,10 +163,10 @@ impl UpstreamManager { let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { u.fails += 1; - warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails); + warn!("Upstream {:?} failed: {}. Consecutive fails: {}", u.config, e, u.fails); if u.fails > 3 { u.healthy = false; - warn!("Upstream disabled due to failures: {:?}", u.config); + warn!("Upstream marked unhealthy: {:?}", u.config); } } Err(e) @@ -145,7 +200,7 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks4 { address, interface, user_id } => { - info!("Connecting to target {} via SOCKS4 proxy {}", target, address); + info!("Connecting to {} via SOCKS4 {}", target, address); let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; @@ -174,7 +229,7 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks5 { address, interface, username, password } => { - info!("Connecting to target {} via SOCKS5 proxy {}", target, address); + info!("Connecting to {} via SOCKS5 {}", target, address); let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; @@ -205,12 +260,109 @@ impl UpstreamManager { } } - /// Background task to check health - pub async fn run_health_checks(&self) { - let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); + // ============= Startup Ping ============= + + /// Ping all Telegram DCs through all upstreams and return results. + /// + /// Used at startup to display connectivity and latency info. + pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec { + let upstreams: Vec<(usize, UpstreamConfig)> = { + let guard = self.upstreams.read().await; + guard.iter().enumerate() + .map(|(i, u)| (i, u.config.clone())) + .collect() + }; + + let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; + + let mut all_results = Vec::new(); + + for (upstream_idx, upstream_config) in &upstreams { + let upstream_name = match &upstream_config.upstream_type { + UpstreamType::Direct { interface } => { + format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default()) + } + UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address), + UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address), + }; + + let mut dc_results = Vec::new(); + + for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() { + let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT); + + let ping_result = tokio::time::timeout( + Duration::from_secs(5), + self.ping_single_dc(upstream_config, dc_addr) + ).await; + + let result = match ping_result { + Ok(Ok(rtt_ms)) => { + // Update latency EMA + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(*upstream_idx) { + u.latency.update(rtt_ms); + } + DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: Some(rtt_ms), + error: None, + } + } + Ok(Err(e)) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: None, + error: Some(e.to_string()), + }, + Err(_) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: None, + error: Some("timeout (5s)".to_string()), + }, + }; + + dc_results.push(result); + } + + all_results.push(StartupPingResult { + results: dc_results, + upstream_name, + }); + } + + all_results + } + + /// Ping a single DC: TCP connect, measure RTT, then drop. + async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result { + let start = Instant::now(); + let _stream = self.connect_via_upstream(config, target).await?; + let rtt = start.elapsed(); + Ok(rtt.as_secs_f64() * 1000.0) + } + + // ============= Health Checks ============= + + /// Background health check task. + /// + /// Every 30 seconds, pings one representative DC per upstream. + /// Measures RTT and updates health status. + pub async fn run_health_checks(&self, prefer_ipv6: bool) { + let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; + + // Rotate through DCs across check cycles + let mut dc_rotation = 0usize; loop { - tokio::time::sleep(Duration::from_secs(60)).await; + tokio::time::sleep(Duration::from_secs(30)).await; + + let check_dc_idx = dc_rotation % datacenters.len(); + dc_rotation += 1; + + let check_target = SocketAddr::new(datacenters[check_dc_idx], TG_DATACENTER_PORT); let count = self.upstreams.read().await.len(); for i in 0..count { @@ -219,6 +371,7 @@ impl UpstreamManager { guard[i].config.clone() }; + let start = Instant::now(); let result = tokio::time::timeout( Duration::from_secs(10), self.connect_via_upstream(&config, check_target) @@ -229,17 +382,42 @@ impl UpstreamManager { match result { Ok(Ok(_stream)) => { + let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; + u.latency.update(rtt_ms); + if !u.healthy { - debug!("Upstream recovered: {:?}", u.config); + info!( + rtt_ms = format!("{:.1}", rtt_ms), + dc = check_dc_idx + 1, + "Upstream recovered: {:?}", u.config + ); } u.healthy = true; u.fails = 0; } Ok(Err(e)) => { - debug!("Health check failed for {:?}: {}", u.config, e); + u.fails += 1; + debug!( + dc = check_dc_idx + 1, + fails = u.fails, + "Health check failed for {:?}: {}", u.config, e + ); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (health check): {:?}", u.config); + } } Err(_) => { - debug!("Health check timeout for {:?}", u.config); + u.fails += 1; + debug!( + dc = check_dc_idx + 1, + fails = u.fails, + "Health check timeout for {:?}", u.config + ); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (timeout): {:?}", u.config); + } } } u.last_check = std::time::Instant::now();