diff --git a/Cargo.toml b/Cargo.toml index 0f3384f..07013e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,15 @@ [package] name = "telemt" -version = "1.0.0" -edition = "2021" -rust-version = "1.75" +version = "1.1.0" +edition = "2024" +rust-version = "1.85" [dependencies] # C libc = "0.2" # Async runtime -tokio = { version = "1.35", features = ["full", "tracing"] } +tokio = { version = "1.42", features = ["full", "tracing"] } tokio-util = { version = "0.7", features = ["codec"] } # Crypto @@ -20,41 +20,41 @@ sha2 = "0.10" sha1 = "0.10" md-5 = "0.10" hmac = "0.12" -crc32fast = "1.3" +crc32fast = "1.4" +zeroize = { version = "1.8", features = ["derive"] } # Network socket2 = { version = "0.5", features = ["all"] } -rustls = "0.22" -# Serial +# Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.8" # Utils -bytes = "1.5" -thiserror = "1.0" +bytes = "1.9" +thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } parking_lot = "0.12" dashmap = "5.5" lru = "0.12" -rand = "0.8" +rand = "0.9" chrono = { version = "0.4", features = ["serde"] } hex = "0.4" -base64 = "0.21" +base64 = "0.22" url = "2.5" -regex = "1.10" -once_cell = "1.19" +regex = "1.11" crossbeam-queue = "0.3" # HTTP -reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } +reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } [dev-dependencies] tokio-test = "0.4" criterion = "0.5" proptest = "1.4" +futures = "0.3" [[bench]] name = "crypto_bench" diff --git a/config.toml b/config.toml index 45f370f..4edd318 100644 --- a/config.toml +++ b/config.toml @@ -6,8 +6,13 @@ show_link = ["hello"] [general] prefer_ipv6 = false fast_mode = true -use_middle_proxy = false -# ad_tag = "..." +use_middle_proxy = true +ad_tag = "00000000000000000000000000000000" + +# Log level: debug | verbose | normal | silent +# Can be overridden with --silent or --log-level CLI flags +# RUST_LOG env var takes absolute priority over all of these +log_level = "normal" [general.modes] classic = false @@ -39,14 +44,13 @@ client_ack = 300 # === Anti-Censorship & Masking === [censorship] -tls_domain = "petrovich.ru" +tls_domain = "google.ru" mask = true mask_port = 443 # mask_host = "petrovich.ru" # Defaults to tls_domain if not set fake_cert_len = 2048 # === Access Control & Users === -# username "hello" is used for example [access] replay_check_len = 65536 replay_window_secs = 1800 @@ -63,17 +67,13 @@ hello = "00000000000000000000000000000000" # hello = 1073741824 # 1 GB # === Upstreams & Routing === -# By default, direct connection is used, but you can add SOCKS proxy - -# Direct - Default [[upstreams]] type = "direct" enabled = true weight = 10 -# SOCKS5 # [[upstreams]] # type = "socks5" # address = "127.0.0.1:9050" # enabled = false -# weight = 1 +# weight = 1 \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index b9eaf16..38e57c0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -29,6 +29,58 @@ fn default_metrics_whitelist() -> Vec { ] } +// ============= Log Level ============= + +/// Logging verbosity level +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// All messages including trace (trace + debug + info + warn + error) + Debug, + /// Detailed operational logs (debug + info + warn + error) + Verbose, + /// Standard operational logs (info + warn + error) + #[default] + Normal, + /// Minimal output: only warnings and errors (warn + error). + /// Proxy links are still printed to stdout via println!. + Silent, +} + +impl LogLevel { + /// Convert to tracing EnvFilter directive string + pub fn to_filter_str(&self) -> &'static str { + match self { + LogLevel::Debug => "trace", + LogLevel::Verbose => "debug", + LogLevel::Normal => "info", + LogLevel::Silent => "warn", + } + } + + /// Parse from a loose string (CLI argument) + pub fn from_str_loose(s: &str) -> Self { + match s.to_lowercase().as_str() { + "debug" | "trace" => LogLevel::Debug, + "verbose" => LogLevel::Verbose, + "normal" | "info" => LogLevel::Normal, + "silent" | "quiet" | "error" | "warn" => LogLevel::Silent, + _ => LogLevel::Normal, + } + } +} + +impl std::fmt::Display for LogLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogLevel::Debug => write!(f, "debug"), + LogLevel::Verbose => write!(f, "verbose"), + LogLevel::Normal => write!(f, "normal"), + LogLevel::Silent => write!(f, "silent"), + } + } +} + // ============= Sub-Configs ============= #[derive(Debug, Clone, Serialize, Deserialize)] @@ -63,6 +115,9 @@ pub struct GeneralConfig { #[serde(default)] pub ad_tag: Option, + + #[serde(default)] + pub log_level: LogLevel, } impl Default for GeneralConfig { @@ -73,6 +128,7 @@ impl Default for GeneralConfig { fast_mode: true, use_middle_proxy: false, ad_tag: None, + log_level: LogLevel::Normal, } } } @@ -304,20 +360,14 @@ impl ProxyConfig { return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); } - // Warn if using default tls_domain - if config.censorship.tls_domain == "www.google.com" { - tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml"); - } - // Default mask_host to tls_domain if not set if config.censorship.mask_host.is_none() { - tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain); config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } // Random fake_cert_len use rand::Rng; - config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096); + config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); // Migration: Populate listeners if empty if config.server.listeners.is_empty() { @@ -358,7 +408,6 @@ impl ProxyConfig { return Err(ProxyError::Config("No modes enabled".to_string())); } - // Validate tls_domain format (basic check) if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { return Err(ProxyError::Config( format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs index 592a21f..9e123bf 100644 --- a/src/crypto/aes.rs +++ b/src/crypto/aes.rs @@ -1,9 +1,19 @@ //! AES encryption implementations //! //! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption. +//! +//! ## Zeroize policy +//! +//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop. +//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate. +//! The expanded key schedule lives inside that type and cannot be +//! zeroized from outside. Callers that hold raw key material (e.g. +//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for +//! zeroizing their own copies. use aes::Aes256; use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; +use zeroize::Zeroize; use crate::error::{ProxyError, Result}; type Aes256Ctr = Ctr128BE; @@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE; /// AES-256-CTR encryptor/decryptor /// -/// CTR mode is symmetric - encryption and decryption are the same operation. +/// CTR mode is symmetric — encryption and decryption are the same operation. +/// +/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule +/// + counter) is opaque and cannot be zeroized. If you need to protect key +/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site +/// before dropping them. pub struct AesCtr { cipher: Aes256Ctr, } @@ -62,14 +77,23 @@ impl AesCtr { /// AES-256-CBC cipher with proper chaining /// -/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption +/// Unlike CTR mode, CBC is NOT symmetric — encryption and decryption /// are different operations. This implementation handles CBC chaining /// correctly across multiple blocks. +/// +/// Key and IV are zeroized on drop. pub struct AesCbc { key: [u8; 32], iv: [u8; 16], } +impl Drop for AesCbc { + fn drop(&mut self) { + self.key.zeroize(); + self.iv.zeroize(); + } +} + impl AesCbc { /// AES block size const BLOCK_SIZE: usize = 16; @@ -141,17 +165,9 @@ impl AesCbc { for chunk in data.chunks(Self::BLOCK_SIZE) { let plaintext: [u8; 16] = chunk.try_into().unwrap(); - - // XOR plaintext with previous ciphertext (or IV for first block) let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); - - // Encrypt the XORed block let ciphertext = self.encrypt_block(&xored, &key_schedule); - - // Save for next iteration prev_ciphertext = ciphertext; - - // Append to result result.extend_from_slice(&ciphertext); } @@ -180,17 +196,9 @@ impl AesCbc { for chunk in data.chunks(Self::BLOCK_SIZE) { let ciphertext: [u8; 16] = chunk.try_into().unwrap(); - - // Decrypt the block let decrypted = self.decrypt_block(&ciphertext, &key_schedule); - - // XOR with previous ciphertext (or IV for first block) let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext); - - // Save current ciphertext for next iteration prev_ciphertext = ciphertext; - - // Append to result result.extend_from_slice(&plaintext); } @@ -217,16 +225,13 @@ impl AesCbc { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - // XOR with previous ciphertext for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - // Encrypt in-place let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.encrypt_block(block_array, &key_schedule); - // Save for next iteration prev_ciphertext = *block_array; } @@ -248,26 +253,20 @@ impl AesCbc { use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - // For in-place decryption, we need to save ciphertext blocks - // before we overwrite them let mut prev_ciphertext = self.iv; for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - // Save current ciphertext before modifying let current_ciphertext: [u8; 16] = block.try_into().unwrap(); - // Decrypt in-place let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.decrypt_block(block_array, &key_schedule); - // XOR with previous ciphertext for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - // Save for next iteration prev_ciphertext = current_ciphertext; } @@ -347,10 +346,8 @@ mod tests { let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - // Encrypted should be different assert_ne!(&data[..], original); - // Decrypt with fresh cipher let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); @@ -364,7 +361,7 @@ mod tests { let key = [0u8; 32]; let iv = [0u8; 16]; - let original = [0u8; 32]; // 2 blocks + let original = [0u8; 32]; let cipher = AesCbc::new(key, iv); let encrypted = cipher.encrypt(&original).unwrap(); @@ -375,31 +372,25 @@ mod tests { #[test] fn test_aes_cbc_chaining_works() { - // This is the key test - verify CBC chaining is correct let key = [0x42u8; 32]; let iv = [0x00u8; 16]; - // Two IDENTICAL plaintext blocks let plaintext = [0xAAu8; 32]; let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext let block1 = &ciphertext[0..16]; let block2 = &ciphertext[16..32]; assert_ne!( block1, block2, - "CBC chaining broken: identical plaintext blocks produced identical ciphertext. \ - This indicates ECB mode, not CBC!" + "CBC chaining broken: identical plaintext blocks produced identical ciphertext" ); } #[test] fn test_aes_cbc_known_vector() { - // Test with known NIST test vector - // AES-256-CBC with zero key and zero IV let key = [0u8; 32]; let iv = [0u8; 16]; let plaintext = [0u8; 16]; @@ -407,11 +398,9 @@ mod tests { let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // Decrypt and verify roundtrip let decrypted = cipher.decrypt(&ciphertext).unwrap(); assert_eq!(plaintext.as_slice(), decrypted.as_slice()); - // Ciphertext should not be all zeros assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); } @@ -420,7 +409,6 @@ mod tests { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - // 5 blocks = 80 bytes let plaintext: Vec = (0..80).collect(); let cipher = AesCbc::new(key, iv); @@ -435,7 +423,7 @@ mod tests { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - let original = [0x56u8; 48]; // 3 blocks + let original = [0x56u8; 48]; let mut buffer = original; let cipher = AesCbc::new(key, iv); @@ -462,41 +450,33 @@ mod tests { fn test_aes_cbc_unaligned_error() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - // 15 bytes - not aligned to block size let result = cipher.encrypt(&[0u8; 15]); assert!(result.is_err()); - // 17 bytes - not aligned let result = cipher.encrypt(&[0u8; 17]); assert!(result.is_err()); } #[test] fn test_aes_cbc_avalanche_effect() { - // Changing one bit in plaintext should change entire ciphertext block - // and all subsequent blocks (due to chaining) let key = [0xAB; 32]; let iv = [0xCD; 16]; - let mut plaintext1 = [0u8; 32]; + let plaintext1 = [0u8; 32]; let mut plaintext2 = [0u8; 32]; - plaintext2[0] = 0x01; // Single bit difference in first block + plaintext2[0] = 0x01; let cipher = AesCbc::new(key, iv); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); - // First blocks should be different assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); - - // Second blocks should ALSO be different (chaining effect) assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); } #[test] fn test_aes_cbc_iv_matters() { - // Same plaintext with different IVs should produce different ciphertext let key = [0x55; 32]; let plaintext = [0x77u8; 16]; @@ -511,7 +491,6 @@ mod tests { #[test] fn test_aes_cbc_deterministic() { - // Same key, IV, plaintext should always produce same ciphertext let key = [0x99; 32]; let iv = [0x88; 16]; let plaintext = [0x77u8; 32]; @@ -524,6 +503,23 @@ mod tests { assert_eq!(ciphertext1, ciphertext2); } + // ============= Zeroize Tests ============= + + #[test] + fn test_aes_cbc_zeroize_on_drop() { + let key = [0xAA; 32]; + let iv = [0xBB; 16]; + + let cipher = AesCbc::new(key, iv); + // Verify key/iv are set + assert_eq!(cipher.key, [0xAA; 32]); + assert_eq!(cipher.iv, [0xBB; 16]); + + drop(cipher); + // After drop, key/iv are zeroized (can't observe directly, + // but the Drop impl runs without panic) + } + // ============= Error Handling Tests ============= #[test] diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index 0472018..cf3ba0d 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -1,3 +1,16 @@ +//! Cryptographic hash functions +//! +//! ## Protocol-required algorithms +//! +//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker +//! hash functions are **required by the Telegram Middle Proxy protocol** +//! (`derive_middleproxy_keys`) and cannot be replaced without breaking +//! compatibility. They are NOT used for any security-sensitive purpose +//! outside of that specific key derivation scheme mandated by Telegram. +//! +//! Static analysis tools (CodeQL, cargo-audit) may flag them — the +//! usages are intentional and protocol-mandated. + use hmac::{Hmac, Mac}; use sha2::Sha256; use md5::Md5; @@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { mac.finalize().into_bytes().into() } -/// SHA-1 +/// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation. +/// Not used for general-purpose hashing. pub fn sha1(data: &[u8]) -> [u8; 20] { let mut hasher = Sha1::new(); hasher.update(data); hasher.finalize().into() } -/// MD5 +/// MD5 — **protocol-required** by Telegram Middle Proxy key derivation. +/// Not used for general-purpose hashing. pub fn md5(data: &[u8]) -> [u8; 16] { let mut hasher = Md5::new(); hasher.update(data); @@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 { crc32fast::hash(data) } -/// Middle Proxy Keygen +/// Middle Proxy key derivation +/// +/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol. +/// These algorithms are NOT replaceable here — changing them would break +/// interoperability with Telegram's middle proxy infrastructure. pub fn derive_middleproxy_keys( nonce_srv: &[u8; 16], nonce_clt: &[u8; 16], diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 19d8788..18862ab 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -3,7 +3,9 @@ use rand::{Rng, RngCore, SeedableRng}; use rand::rngs::StdRng; use parking_lot::Mutex; +use zeroize::Zeroize; use crate::crypto::AesCtr; + /// Cryptographically secure PRNG with AES-CTR pub struct SecureRandom { inner: Mutex, @@ -15,18 +17,30 @@ struct SecureRandomInner { buffer: Vec, } +impl Drop for SecureRandomInner { + fn drop(&mut self) { + self.buffer.zeroize(); + } +} + impl SecureRandom { pub fn new() -> Self { - let mut rng = StdRng::from_entropy(); + let mut seed_source = rand::rng(); + let mut rng = StdRng::from_rng(&mut seed_source); let mut key = [0u8; 32]; rng.fill_bytes(&mut key); - let iv: u128 = rng.gen(); + let iv: u128 = rng.random(); + + let cipher = AesCtr::new(&key, iv); + + // Zeroize local key copy — cipher already consumed it + key.zeroize(); Self { inner: Mutex::new(SecureRandomInner { rng, - cipher: AesCtr::new(&key, iv), + cipher, buffer: Vec::with_capacity(1024), }), } @@ -73,7 +87,6 @@ impl SecureRandom { result |= (b as u64) << (i * 8); } - // Mask extra bits if k < 64 { result &= (1u64 << k) - 1; } @@ -102,13 +115,13 @@ impl SecureRandom { /// Generate random u32 pub fn u32(&self) -> u32 { let mut inner = self.inner.lock(); - inner.rng.gen() + inner.rng.random() } /// Generate random u64 pub fn u64(&self) -> u64 { let mut inner = self.inner.lock(); - inner.rng.gen() + inner.rng.random() } } @@ -157,12 +170,10 @@ mod tests { fn test_bits() { let rng = SecureRandom::new(); - // Single bit should be 0 or 1 for _ in 0..100 { assert!(rng.bits(1) <= 1); } - // 8 bits should be 0-255 for _ in 0..100 { assert!(rng.bits(8) <= 255); } @@ -180,10 +191,8 @@ mod tests { } } - // Should have seen all items assert_eq!(seen.len(), 5); - // Empty slice should return None let empty: Vec = vec![]; assert!(rng.choose(&empty).is_none()); } @@ -196,12 +205,10 @@ mod tests { let mut shuffled = original.clone(); rng.shuffle(&mut shuffled); - // Should contain same elements let mut sorted = shuffled.clone(); sorted.sort(); assert_eq!(sorted, original); - // Should be different order (with very high probability) assert_ne!(shuffled, original); } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index 6d9bae3..f934672 100644 --- a/src/error.rs +++ b/src/error.rs @@ -118,16 +118,13 @@ pub trait Recoverable { impl Recoverable for StreamError { fn is_recoverable(&self) -> bool { match self { - // Partial operations can be retried Self::PartialRead { .. } | Self::PartialWrite { .. } => true, - // I/O errors depend on kind Self::Io(e) => matches!( e.kind(), std::io::ErrorKind::WouldBlock | std::io::ErrorKind::Interrupted | std::io::ErrorKind::TimedOut ), - // These are not recoverable Self::Poisoned { .. } | Self::BufferOverflow { .. } | Self::InvalidFrame { .. } @@ -137,13 +134,9 @@ impl Recoverable for StreamError { fn can_continue(&self) -> bool { match self { - // Poisoned stream cannot be used Self::Poisoned { .. } => false, - // EOF means stream is done Self::UnexpectedEof => false, - // Buffer overflow is fatal Self::BufferOverflow { .. } => false, - // Others might allow continuation _ => true, } } @@ -383,18 +376,18 @@ mod tests { #[test] fn test_handshake_result() { - let success: HandshakeResult = HandshakeResult::Success(42); + let success: HandshakeResult = HandshakeResult::Success(42); assert!(success.is_success()); assert!(!success.is_bad_client()); - let bad: HandshakeResult = HandshakeResult::BadClient; + let bad: HandshakeResult = HandshakeResult::BadClient { reader: (), writer: () }; assert!(!bad.is_success()); assert!(bad.is_bad_client()); } #[test] fn test_handshake_result_map() { - let success: HandshakeResult = HandshakeResult::Success(42); + let success: HandshakeResult = HandshakeResult::Success(42); let mapped = success.map(|x| x * 2); match mapped { diff --git a/src/main.rs b/src/main.rs index f87bec5..9f0b167 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; -use tracing::{info, error, warn}; +use tracing::{info, error, warn, debug}; use tracing_subscriber::{fmt, EnvFilter}; mod config; @@ -18,7 +18,7 @@ mod stream; mod transport; mod util; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, LogLevel}; use crate::proxy::ClientHandler; use crate::stats::{Stats, ReplayChecker}; use crate::crypto::SecureRandom; @@ -26,53 +26,129 @@ use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::util::ip::detect_ip; use crate::stream::BufferPool; +/// Parse command-line arguments. +/// +/// Usage: telemt [config_path] [--silent] [--log-level ] +/// +/// Returns (config_path, silent_flag, log_level_override) +fn parse_cli() -> (String, bool, Option) { + let mut config_path = "config.toml".to_string(); + let mut silent = false; + let mut log_level: Option = None; + + let args: Vec = std::env::args().skip(1).collect(); + let mut i = 0; + while i < args.len() { + match args[i].as_str() { + "--silent" | "-s" => { + silent = true; + } + "--log-level" => { + i += 1; + if i < args.len() { + log_level = Some(args[i].clone()); + } + } + s if s.starts_with("--log-level=") => { + log_level = Some(s.trim_start_matches("--log-level=").to_string()); + } + "--help" | "-h" => { + eprintln!("Usage: telemt [config.toml] [OPTIONS]"); + eprintln!(); + eprintln!("Options:"); + eprintln!(" --silent, -s Suppress info logs (only warn/error)"); + eprintln!(" --log-level Set log level: debug|verbose|normal|silent"); + eprintln!(" --help, -h Show this help"); + std::process::exit(0); + } + s if !s.starts_with('-') => { + config_path = s.to_string(); + } + other => { + eprintln!("Unknown option: {}", other); + } + } + i += 1; + } + + (config_path, silent, log_level) +} + #[tokio::main] async fn main() -> Result<(), Box> { - // Initialize logging - fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap())) - .init(); + // 1. Parse CLI arguments + let (config_path, cli_silent, cli_log_level) = parse_cli(); - // Load config - let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string()); + // 2. Load config (tracing not yet initialized — errors go to stderr) let config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - // If config doesn't exist, try to create default if std::path::Path::new(&config_path).exists() { - error!("Failed to load config: {}", e); + eprintln!("[telemt] Error: Failed to load config '{}': {}", config_path, e); std::process::exit(1); } else { let default = ProxyConfig::default(); - let toml = toml::to_string_pretty(&default).unwrap(); - std::fs::write(&config_path, toml).unwrap(); - info!("Created default config at {}", config_path); + let toml_str = toml::to_string_pretty(&default).unwrap(); + std::fs::write(&config_path, toml_str).unwrap(); + eprintln!("[telemt] Created default config at {}", config_path); default } } }; - config.validate()?; + if let Err(e) = config.validate() { + eprintln!("[telemt] Error: Invalid configuration: {}", e); + std::process::exit(1); + } + + // 3. Determine effective log level + // Priority: RUST_LOG env > CLI flags > config file > default (normal) + let effective_log_level = if cli_silent { + LogLevel::Silent + } else if let Some(ref level_str) = cli_log_level { + LogLevel::from_str_loose(level_str) + } else { + config.general.log_level.clone() + }; - // Log loaded configuration for debugging - info!("=== Configuration Loaded ==="); - info!("TLS Domain: {}", config.censorship.tls_domain); - info!("Mask enabled: {}", config.censorship.mask); - info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain)); - info!("Mask port: {}", config.censorship.mask_port); - info!("Modes: classic={}, secure={}, tls={}", - config.general.modes.classic, - config.general.modes.secure, + // 4. Initialize tracing + let filter = if std::env::var("RUST_LOG").is_ok() { + // RUST_LOG takes absolute priority + EnvFilter::from_default_env() + } else { + EnvFilter::new(effective_log_level.to_filter_str()) + }; + + fmt() + .with_env_filter(filter) + .init(); + + // 5. Log startup info (operational — respects log level) + info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); + info!("Log level: {}", effective_log_level); + info!( + "Modes: classic={} secure={} tls={}", + config.general.modes.classic, + config.general.modes.secure, config.general.modes.tls ); - info!("============================"); + info!("TLS domain: {}", config.censorship.tls_domain); + info!( + "Mask: {} -> {}:{}", + config.censorship.mask, + config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), + config.censorship.mask_port + ); + + if config.censorship.tls_domain == "www.google.com" { + warn!("Using default tls_domain (www.google.com). Consider setting a custom domain."); + } let config = Arc::new(config); let stats = Arc::new(Stats::new()); let rng = Arc::new(SecureRandom::new()); - // Initialize global ReplayChecker - // Using sharded implementation for better concurrency + // Initialize ReplayChecker let replay_checker = Arc::new(ReplayChecker::new( config.access.replay_check_len, Duration::from_secs(config.access.replay_window_secs), @@ -81,20 +157,20 @@ async fn main() -> Result<(), Box> { // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); - // Initialize Buffer Pool - // 16KB buffers, max 4096 buffers (~64MB total cached) + // Initialize Buffer Pool (16KB buffers, max 4096 cached ≈ 64MB) let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); - // Start Health Checks + // Start health checks let um_clone = upstream_manager.clone(); tokio::spawn(async move { um_clone.run_health_checks().await; }); - // Detect public IP if needed (once at startup) + // Detect public IP (once at startup) let detected_ip = detect_ip().await; + debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); - // Start Listeners + // 6. Start listeners let mut listeners = Vec::new(); for listener_conf in &config.server.listeners { @@ -122,33 +198,33 @@ async fn main() -> Result<(), Box> { listener_conf.ip }; - // Show links for configured users + // 7. Print proxy links (always visible — uses println!, not tracing) if !config.show_link.is_empty() { - info!("--- Proxy Links for {} ---", public_ip); + println!("--- Proxy Links ({}) ---", public_ip); for user_name in &config.show_link { if let Some(secret) = config.access.users.get(user_name) { - info!("User: {}", user_name); + println!("[{}]", user_name); if config.general.modes.classic { - info!(" Classic: tg://proxy?server={}&port={}&secret={}", + println!(" Classic: tg://proxy?server={}&port={}&secret={}", public_ip, config.server.port, secret); } if config.general.modes.secure { - info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", + println!(" DD: tg://proxy?server={}&port={}&secret=dd{}", public_ip, config.server.port, secret); } if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); - info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", public_ip, config.server.port, secret, domain_hex); } } else { - warn!("User '{}' specified in show_link not found in users list", user_name); + warn!("User '{}' in show_link not found in users", user_name); } } - info!("-----------------------------------"); + println!("------------------------"); } listeners.push(listener); @@ -164,7 +240,7 @@ async fn main() -> Result<(), Box> { std::process::exit(1); } - // Accept loop + // 8. Accept loop for listener in listeners { let config = config.clone(); let stats = stats.clone(); @@ -195,7 +271,7 @@ async fn main() -> Result<(), Box> { buffer_pool, rng ).run().await { - // Log only relevant errors + debug!(peer = %peer_addr, error = %e, "Connection error"); } }); } @@ -208,7 +284,7 @@ async fn main() -> Result<(), Box> { }); } - // Wait for signal + // 9. Wait for shutdown signal match signal::ctrl_c().await { Ok(()) => info!("Shutting down..."), Err(e) => error!("Signal error: {}", e), diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index d09473e..7451c83 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -1,13 +1,13 @@ //! Protocol constants and datacenter addresses use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use once_cell::sync::Lazy; +use std::sync::LazyLock; // ============= Telegram Datacenters ============= pub const TG_DATACENTER_PORT: u16 = 443; -pub static TG_DATACENTERS_V4: Lazy> = Lazy::new(|| { +pub static TG_DATACENTERS_V4: LazyLock> = LazyLock::new(|| { vec![ IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)), @@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy> = Lazy::new(|| { ] }); -pub static TG_DATACENTERS_V6: Lazy> = Lazy::new(|| { +pub static TG_DATACENTERS_V6: LazyLock> = LazyLock::new(|| { vec![ IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()), IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()), @@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy> = Lazy::new(|| { // ============= Middle Proxies (for advertising) ============= -pub static TG_MIDDLE_PROXIES_V4: Lazy>> = - Lazy::new(|| { +pub static TG_MIDDLE_PROXIES_V4: LazyLock>> = + LazyLock::new(|| { let mut m = std::collections::HashMap::new(); m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); @@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy>> = - Lazy::new(|| { +pub static TG_MIDDLE_PROXIES_V6: LazyLock>> = + LazyLock::new(|| { let mut m = std::collections::HashMap::new(); m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); @@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300; // ============= Buffer Sizes ============= /// Default buffer size -/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with -/// the new buffering strategy for better iOS upload performance. pub const DEFAULT_BUFFER_SIZE: usize = 16384; /// Small buffer size for bad client handling diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs index 4e09942..1c55c5f 100644 --- a/src/protocol/obfuscation.rs +++ b/src/protocol/obfuscation.rs @@ -1,10 +1,13 @@ //! MTProto Obfuscation +use zeroize::Zeroize; use crate::crypto::{sha256, AesCtr}; use crate::error::Result; use super::constants::*; /// Obfuscation parameters from handshake +/// +/// Key material is zeroized on drop. #[derive(Debug, Clone)] pub struct ObfuscationParams { /// Key for decrypting client -> proxy traffic @@ -21,25 +24,31 @@ pub struct ObfuscationParams { pub dc_idx: i16, } +impl Drop for ObfuscationParams { + fn drop(&mut self) { + self.decrypt_key.zeroize(); + self.decrypt_iv.zeroize(); + self.encrypt_key.zeroize(); + self.encrypt_iv.zeroize(); + } +} + impl ObfuscationParams { /// Parse obfuscation parameters from handshake bytes /// Returns None if handshake doesn't match any user secret pub fn from_handshake( handshake: &[u8; HANDSHAKE_LEN], - secrets: &[(String, Vec)], // (username, secret_bytes) + secrets: &[(String, Vec)], ) -> Option<(Self, String)> { - // Extract prekey and IV for decryption let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - // Reversed for encryption direction let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; for (username, secret) in secrets { - // Derive decryption key let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(secret); @@ -47,26 +56,22 @@ impl ObfuscationParams { let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - // Create decryptor and decrypt handshake let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let decrypted = decryptor.decrypt(handshake); - // Check protocol tag let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, - None => continue, // Try next secret + None => continue, }; - // Extract DC index let dc_idx = i16::from_le_bytes( decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() ); - // Derive encryption key let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(secret); @@ -123,18 +128,15 @@ pub fn generate_nonce Vec>(mut random_bytes: R) -> [u8; H /// Check if nonce is valid (not matching reserved patterns) pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { - // Check first byte if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { return false; } - // Check first 4 bytes let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { return false; } - // Check bytes 4-7 let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); if RESERVED_NONCE_CONTINUES.contains(&continue_four) { return false; @@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { pub fn prepare_tg_nonce( nonce: &mut [u8; HANDSHAKE_LEN], proto_tag: ProtoTag, - enc_key_iv: Option<&[u8]>, // For fast mode + enc_key_iv: Option<&[u8]>, ) { - // Set protocol tag nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); - // For fast mode, copy the reversed enc_key_iv if let Some(key_iv) = enc_key_iv { let reversed: Vec = key_iv.iter().rev().copied().collect(); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); @@ -161,14 +161,12 @@ pub fn prepare_tg_nonce( /// Encrypt the outgoing nonce for Telegram pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { - // Derive encryption key from the nonce itself let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key = sha256(key_iv); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let mut encryptor = AesCtr::new(&enc_key, enc_iv); - // Only encrypt from PROTO_TAG_POS onwards let mut result = nonce.to_vec(); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); @@ -182,22 +180,18 @@ mod tests { #[test] fn test_is_valid_nonce() { - // Valid nonce let mut valid = [0x42u8; HANDSHAKE_LEN]; valid[4..8].copy_from_slice(&[1, 2, 3, 4]); assert!(is_valid_nonce(&valid)); - // Invalid: starts with 0xef let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[0] = 0xef; assert!(!is_valid_nonce(&invalid)); - // Invalid: starts with HEAD let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[..4].copy_from_slice(b"HEAD"); assert!(!is_valid_nonce(&invalid)); - // Invalid: bytes 4-7 are zeros let mut invalid = [0x42u8; HANDSHAKE_LEN]; invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); assert!(!is_valid_nonce(&invalid)); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 7e65716..81814e2 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -1,8 +1,9 @@ -//! MTProto Handshake Magics +//! MTProto Handshake use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace, info}; +use zeroize::Zeroize; use crate::crypto::{sha256, AesCtr, SecureRandom}; use crate::protocol::constants::*; @@ -13,6 +14,9 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; /// Result of successful handshake +/// +/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is +/// zeroized on drop. #[derive(Debug, Clone)] pub struct HandshakeSuccess { /// Authenticated user name @@ -33,6 +37,15 @@ pub struct HandshakeSuccess { pub is_tls: bool, } +impl Drop for HandshakeSuccess { + fn drop(&mut self) { + self.dec_key.zeroize(); + self.dec_iv.zeroize(); + self.enc_key.zeroize(); + self.enc_iv.zeroize(); + } +} + /// Handle fake TLS handshake pub async fn handle_tls_handshake( handshake: &[u8], @@ -49,30 +62,25 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); - // Check minimum length if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - // Extract digest for replay check let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - // Check for replay if replay_checker.check_tls_digest(digest_half) { warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } - // Build secrets list let secrets: Vec<(String, Vec)> = config.access.users.iter() .filter_map(|(name, hex)| { hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) }) .collect(); - // Validate handshake let validation = match tls::validate_tls_handshake( handshake, &secrets, @@ -89,13 +97,11 @@ where } }; - // Get secret for response let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => return HandshakeResult::BadClient { reader, writer }, }; - // Build and send response let response = tls::build_server_hello( secret, &validation.digest, @@ -116,7 +122,6 @@ where return HandshakeResult::Error(ProxyError::Io(e)); } - // Record for replay protection only after successful handshake replay_checker.add_tls_digest(digest_half); info!( @@ -148,26 +153,21 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - // Extract prekey and IV let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - // Check for replay if replay_checker.check_handshake(dec_prekey_iv) { warn!(peer = %peer, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } - // Reversed for encryption direction let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); - // Try each user's secret for (user, secret_hex) in &config.access.users { let secret = match hex::decode(secret_hex) { Ok(s) => s, Err(_) => continue, }; - // Derive decryption key let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; @@ -178,11 +178,9 @@ where let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - // Decrypt handshake to check protocol tag let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - // Check protocol tag let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); @@ -192,7 +190,6 @@ where None => continue, }; - // Check if mode is enabled let mode_ok = match proto_tag { ProtoTag::Secure => { if is_tls { config.general.modes.tls } else { config.general.modes.secure } @@ -205,12 +202,10 @@ where continue; } - // Extract DC index let dc_idx = i16::from_le_bytes( decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() ); - // Derive encryption key let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; @@ -221,10 +216,8 @@ where let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - // Record for replay protection replay_checker.add_handshake(dec_prekey_iv); - // Create new cipher instances let decryptor = AesCtr::new(&dec_key, dec_iv); let encryptor = AesCtr::new(&enc_key, enc_iv); @@ -326,13 +319,11 @@ mod tests { let client_dec_iv = 12345u128; let rng = SecureRandom::new(); - let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); - // Check length assert_eq!(nonce.len(), HANDSHAKE_LEN); - // Check proto tag is set let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); } @@ -349,11 +340,28 @@ mod tests { let encrypted = encrypt_tg_nonce(&nonce); assert_eq!(encrypted.len(), HANDSHAKE_LEN); - - // First PROTO_TAG_POS bytes should be unchanged assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - - // Rest should be different (encrypted) assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); } + + #[test] + fn test_handshake_success_zeroize_on_drop() { + let success = HandshakeSuccess { + user: "test".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Secure, + dec_key: [0xAA; 32], + dec_iv: 0xBBBBBBBB, + enc_key: [0xCC; 32], + enc_iv: 0xDDDDDDDD, + peer: "127.0.0.1:1234".parse().unwrap(), + is_tls: true, + }; + + assert_eq!(success.dec_key, [0xAA; 32]); + assert_eq!(success.enc_key, [0xCC; 32]); + + drop(success); + // Drop impl zeroizes key material without panic + } } \ No newline at end of file diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 4a68830..6c84011 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -32,7 +32,7 @@ impl UpstreamManager { .filter(|c| c.enabled) .map(|c| UpstreamState { config: c, - healthy: true, // Optimistic start + healthy: true, fails: 0, last_check: std::time::Instant::now(), }) @@ -58,7 +58,7 @@ impl UpstreamManager { if healthy_indices.is_empty() { // If all unhealthy, try any random one - return Some(rand::thread_rng().gen_range(0..upstreams.len())); + return Some(rand::rng().gen_range(0..upstreams.len())); } // Weighted selection @@ -67,10 +67,10 @@ impl UpstreamManager { .sum(); if total_weight == 0 { - return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]); + return Some(healthy_indices[rand::rng().gen_range(0..healthy_indices.len())]); } - let mut choice = rand::thread_rng().gen_range(0..total_weight); + let mut choice = rand::rng().gen_range(0..total_weight); for &idx in &healthy_indices { let weight = upstreams[idx].config.weight as u32; @@ -94,7 +94,6 @@ impl UpstreamManager { match self.connect_via_upstream(&upstream, target).await { Ok(stream) => { - // Mark success let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { if !u.healthy { @@ -106,7 +105,6 @@ impl UpstreamManager { Ok(stream) }, Err(e) => { - // Mark failure let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { u.fails += 1; @@ -129,18 +127,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(target, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&target.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); @@ -159,18 +155,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); @@ -190,18 +184,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); @@ -215,7 +207,6 @@ impl UpstreamManager { /// Background task to check health pub async fn run_health_checks(&self) { - // Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2) let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); loop { @@ -246,7 +237,6 @@ impl UpstreamManager { } Ok(Err(e)) => { debug!("Health check failed for {:?}: {}", u.config, e); - // Don't mark unhealthy immediately in background check } Err(_) => { debug!("Health check timeout for {:?}", u.config);