diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index db1123f..7945e70 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -14,6 +14,11 @@ jobs: name: Build runs-on: ubuntu-latest + permissions: + contents: read + actions: write + checks: write + steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/Cargo.toml b/Cargo.toml index 0f3384f..f2992aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,14 @@ [package] name = "telemt" -version = "1.0.0" -edition = "2021" -rust-version = "1.75" +version = "1.2.0" +edition = "2024" [dependencies] # C libc = "0.2" # Async runtime -tokio = { version = "1.35", features = ["full", "tracing"] } +tokio = { version = "1.42", features = ["full", "tracing"] } tokio-util = { version = "0.7", features = ["codec"] } # Crypto @@ -20,41 +19,41 @@ sha2 = "0.10" sha1 = "0.10" md-5 = "0.10" hmac = "0.12" -crc32fast = "1.3" +crc32fast = "1.4" +zeroize = { version = "1.8", features = ["derive"] } # Network socket2 = { version = "0.5", features = ["all"] } -rustls = "0.22" -# Serial +# Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.8" # Utils -bytes = "1.5" -thiserror = "1.0" +bytes = "1.9" +thiserror = "2.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } parking_lot = "0.12" dashmap = "5.5" lru = "0.12" -rand = "0.8" +rand = "0.9" chrono = { version = "0.4", features = ["serde"] } hex = "0.4" -base64 = "0.21" +base64 = "0.22" url = "2.5" -regex = "1.10" -once_cell = "1.19" +regex = "1.11" crossbeam-queue = "0.3" # HTTP -reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } +reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false } [dev-dependencies] tokio-test = "0.4" criterion = "0.5" proptest = "1.4" +futures = "0.3" [[bench]] name = "crypto_bench" diff --git a/config.toml b/config.toml index 7b1dc8a..4edd318 100644 --- a/config.toml +++ b/config.toml @@ -6,8 +6,13 @@ show_link = ["hello"] [general] prefer_ipv6 = false fast_mode = true -use_middle_proxy = false -# ad_tag = "..." +use_middle_proxy = true +ad_tag = "00000000000000000000000000000000" + +# Log level: debug | verbose | normal | silent +# Can be overridden with --silent or --log-level CLI flags +# RUST_LOG env var takes absolute priority over all of these +log_level = "normal" [general.modes] classic = false @@ -39,16 +44,16 @@ client_ack = 300 # === Anti-Censorship & Masking === [censorship] -tls_domain = "petrovich.ru" +tls_domain = "google.ru" mask = true mask_port = 443 # mask_host = "petrovich.ru" # Defaults to tls_domain if not set fake_cert_len = 2048 # === Access Control & Users === -# username "hello" is used for example [access] replay_check_len = 65536 +replay_window_secs = 1800 ignore_time_skew = false [access.users] @@ -62,17 +67,13 @@ hello = "00000000000000000000000000000000" # hello = 1073741824 # 1 GB # === Upstreams & Routing === -# By default, direct connection is used, but you can add SOCKS proxy - -# Direct - Default [[upstreams]] type = "direct" enabled = true weight = 10 -# SOCKS5 # [[upstreams]] # type = "socks5" # address = "127.0.0.1:9050" # enabled = false -# weight = 1 +# weight = 1 \ No newline at end of file diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..1440a63 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,300 @@ +//! 
CLI commands: --init (fire-and-forget setup)
+
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::process::Command;
+use rand::Rng;
+
+/// Options for the init command
+pub struct InitOptions {
+    pub port: u16,
+    pub domain: String,
+    pub secret: Option<String>,
+    pub username: String,
+    pub config_dir: PathBuf,
+    pub no_start: bool,
+}
+
+impl Default for InitOptions {
+    fn default() -> Self {
+        Self {
+            port: 443,
+            domain: "www.google.com".to_string(),
+            secret: None,
+            username: "user".to_string(),
+            config_dir: PathBuf::from("/etc/telemt"),
+            no_start: false,
+        }
+    }
+}
+
+/// Parse --init subcommand options from CLI args.
+///
+/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
+pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
+    if !args.iter().any(|a| a == "--init") {
+        return None;
+    }
+
+    let mut opts = InitOptions::default();
+    let mut i = 0;
+
+    while i < args.len() {
+        match args[i].as_str() {
+            "--port" => {
+                i += 1;
+                if i < args.len() {
+                    opts.port = args[i].parse().unwrap_or(443);
+                }
+            }
+            "--domain" => {
+                i += 1;
+                if i < args.len() {
+                    opts.domain = args[i].clone();
+                }
+            }
+            "--secret" => {
+                i += 1;
+                if i < args.len() {
+                    opts.secret = Some(args[i].clone());
+                }
+            }
+            "--user" => {
+                i += 1;
+                if i < args.len() {
+                    opts.username = args[i].clone();
+                }
+            }
+            "--config-dir" => {
+                i += 1;
+                if i < args.len() {
+                    opts.config_dir = PathBuf::from(&args[i]);
+                }
+            }
+            "--no-start" => {
+                opts.no_start = true;
+            }
+            _ => {}
+        }
+        i += 1;
+    }
+
+    Some(opts)
+}
+
+/// Run the fire-and-forget setup.
+pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
+    eprintln!("[telemt] Fire-and-forget setup");
+    eprintln!();
+
+    // 1. Generate or validate secret
+    let secret = match opts.secret {
+        Some(s) => {
+            if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
+                eprintln!("[error] Secret must be exactly 32 hex characters");
+                std::process::exit(1);
+            }
+            s
+        }
+        None => generate_secret(),
+    };
+
+    eprintln!("[+] Secret: {}", secret);
+    eprintln!("[+] User: {}", opts.username);
+    eprintln!("[+] Port: {}", opts.port);
+    eprintln!("[+] Domain: {}", opts.domain);
+
+    // 2. Create config directory
+    fs::create_dir_all(&opts.config_dir)?;
+    let config_path = opts.config_dir.join("config.toml");
+
+    // 3. Write config
+    let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
+    fs::write(&config_path, &config_content)?;
+    eprintln!("[+] Config written to {}", config_path.display());
+
+    // 4. Write systemd unit
+    let exe_path = std::env::current_exe()
+        .unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
+
+    let unit_path = Path::new("/etc/systemd/system/telemt.service");
+    let unit_content = generate_systemd_unit(&exe_path, &config_path);
+
+    match fs::write(unit_path, &unit_content) {
+        Ok(()) => {
+            eprintln!("[+] Systemd unit written to {}", unit_path.display());
+        }
+        Err(e) => {
+            eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
+            eprintln!("[!] Manual unit file content:");
+            eprintln!("{}", unit_content);
+
+            // Still print links and config
+            print_links(&opts.username, &secret, opts.port, &opts.domain);
+            return Ok(());
+        }
+    }
+
+    // 5. Reload systemd
+    run_cmd("systemctl", &["daemon-reload"]);
+
+    // 6. Enable service
+    run_cmd("systemctl", &["enable", "telemt.service"]);
+    eprintln!("[+] Service enabled");
+
+    // 7. Start service (unless --no-start)
+    if !opts.no_start {
+        run_cmd("systemctl", &["start", "telemt.service"]);
+        eprintln!("[+] Service started");
+
+        // Brief delay then check status
+        std::thread::sleep(std::time::Duration::from_secs(1));
+        let status = Command::new("systemctl")
+            .args(["is-active", "telemt.service"])
+            .output();
+
+        match status {
+            Ok(out) if out.status.success() => {
+                eprintln!("[+] Service is running");
+            }
+            _ => {
+                eprintln!("[!] Service may not have started correctly");
+                eprintln!("[!] Check: journalctl -u telemt.service -n 20");
+            }
+        }
+    } else {
+        eprintln!("[+] Service not started (--no-start)");
+        eprintln!("[+] Start manually: systemctl start telemt.service");
+    }
+
+    eprintln!();
+
+    // 8. Print links
+    print_links(&opts.username, &secret, opts.port, &opts.domain);
+
+    Ok(())
+}
+
+fn generate_secret() -> String {
+    let mut rng = rand::rng();
+    let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
+    hex::encode(bytes)
+}
+
+fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
+    format!(
+r#"# Telemt MTProxy — auto-generated config
+# Re-run `telemt --init` to regenerate
+
+show_link = ["{username}"]
+
+[general]
+prefer_ipv6 = false
+fast_mode = true
+use_middle_proxy = false
+log_level = "normal"
+
+[general.modes]
+classic = false
+secure = false
+tls = true
+
+[server]
+port = {port}
+listen_addr_ipv4 = "0.0.0.0"
+listen_addr_ipv6 = "::"
+
+[[server.listeners]]
+ip = "0.0.0.0"
+
+[[server.listeners]]
+ip = "::"
+
+[timeouts]
+client_handshake = 15
+tg_connect = 10
+client_keepalive = 60
+client_ack = 300
+
+[censorship]
+tls_domain = "{domain}"
+mask = true
+mask_port = 443
+fake_cert_len = 2048
+
+[access]
+replay_check_len = 65536
+replay_window_secs = 1800
+ignore_time_skew = false
+
+[access.users]
+{username} = "{secret}"
+
+[[upstreams]]
+type = "direct"
+enabled = true
+weight = 10
+"#,
+        username = username,
+        secret = secret,
+        port = port,
+        domain = domain,
+    )
+}
+
+fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
+    format!(
+r#"[Unit]
+Description=Telemt MTProxy
+Documentation=https://github.com/nicepkg/telemt
+After=network-online.target
+Wants=network-online.target
+
+[Service]
+Type=simple
+ExecStart={exe} {config}
+Restart=always
+RestartSec=5
+LimitNOFILE=65535
+# Security hardening
+NoNewPrivileges=true
+ProtectSystem=strict
+ProtectHome=true
+ReadWritePaths=/etc/telemt
+PrivateTmp=true
+
+[Install]
+WantedBy=multi-user.target
+"#,
+        exe = exe_path.display(),
+        config = config_path.display(),
+    )
+}
+
+fn run_cmd(cmd: &str, args: &[&str]) {
+    match Command::new(cmd).args(args).output() {
+        Ok(output) => {
+            if !output.status.success() {
+                let stderr = String::from_utf8_lossy(&output.stderr);
+                eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
+            }
+        }
+        Err(e) => {
+            eprintln!("[!] 
Failed to run {} {}: {}", cmd, args.join(" "), e); + } + } +} + +fn print_links(username: &str, secret: &str, port: u16, domain: &str) { + let domain_hex = hex::encode(domain); + + println!("=== Proxy Links ==="); + println!("[{}]", username); + println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}", + port, secret, domain_hex); + println!(); + println!("Replace YOUR_SERVER_IP with your server's public IP."); + println!("The proxy will auto-detect and display the correct link on startup."); + println!("Check: journalctl -u telemt.service | head -30"); + println!("==================="); +} \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index 425aeef..38e57c0 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -14,6 +14,7 @@ fn default_port() -> u16 { 443 } fn default_tls_domain() -> String { "www.google.com".to_string() } fn default_mask_port() -> u16 { 443 } fn default_replay_check_len() -> usize { 65536 } +fn default_replay_window_secs() -> u64 { 1800 } fn default_handshake_timeout() -> u64 { 15 } fn default_connect_timeout() -> u64 { 10 } fn default_keepalive() -> u64 { 60 } @@ -28,6 +29,58 @@ fn default_metrics_whitelist() -> Vec { ] } +// ============= Log Level ============= + +/// Logging verbosity level +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// All messages including trace (trace + debug + info + warn + error) + Debug, + /// Detailed operational logs (debug + info + warn + error) + Verbose, + /// Standard operational logs (info + warn + error) + #[default] + Normal, + /// Minimal output: only warnings and errors (warn + error). + /// Proxy links are still printed to stdout via println!. + Silent, +} + +impl LogLevel { + /// Convert to tracing EnvFilter directive string + pub fn to_filter_str(&self) -> &'static str { + match self { + LogLevel::Debug => "trace", + LogLevel::Verbose => "debug", + LogLevel::Normal => "info", + LogLevel::Silent => "warn", + } + } + + /// Parse from a loose string (CLI argument) + pub fn from_str_loose(s: &str) -> Self { + match s.to_lowercase().as_str() { + "debug" | "trace" => LogLevel::Debug, + "verbose" => LogLevel::Verbose, + "normal" | "info" => LogLevel::Normal, + "silent" | "quiet" | "error" | "warn" => LogLevel::Silent, + _ => LogLevel::Normal, + } + } +} + +impl std::fmt::Display for LogLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogLevel::Debug => write!(f, "debug"), + LogLevel::Verbose => write!(f, "verbose"), + LogLevel::Normal => write!(f, "normal"), + LogLevel::Silent => write!(f, "silent"), + } + } +} + // ============= Sub-Configs ============= #[derive(Debug, Clone, Serialize, Deserialize)] @@ -62,6 +115,9 @@ pub struct GeneralConfig { #[serde(default)] pub ad_tag: Option, + + #[serde(default)] + pub log_level: LogLevel, } impl Default for GeneralConfig { @@ -72,6 +128,7 @@ impl Default for GeneralConfig { fast_mode: true, use_middle_proxy: false, ad_tag: None, + log_level: LogLevel::Normal, } } } @@ -187,6 +244,9 @@ pub struct AccessConfig { #[serde(default = "default_replay_check_len")] pub replay_check_len: usize, + #[serde(default = "default_replay_window_secs")] + pub replay_window_secs: u64, + #[serde(default)] pub ignore_time_skew: bool, } @@ -201,6 +261,7 @@ impl Default for AccessConfig { user_expirations: HashMap::new(), user_data_quota: HashMap::new(), replay_check_len: default_replay_check_len(), + replay_window_secs: 
default_replay_window_secs(), ignore_time_skew: false, } } @@ -299,20 +360,14 @@ impl ProxyConfig { return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); } - // Warn if using default tls_domain - if config.censorship.tls_domain == "www.google.com" { - tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml"); - } - // Default mask_host to tls_domain if not set if config.censorship.mask_host.is_none() { - tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain); config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } // Random fake_cert_len use rand::Rng; - config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096); + config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); // Migration: Populate listeners if empty if config.server.listeners.is_empty() { @@ -353,7 +408,6 @@ impl ProxyConfig { return Err(ProxyError::Config("No modes enabled".to_string())); } - // Validate tls_domain format (basic check) if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { return Err(ProxyError::Config( format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) diff --git a/src/crypto/aes.rs b/src/crypto/aes.rs index 592a21f..9e123bf 100644 --- a/src/crypto/aes.rs +++ b/src/crypto/aes.rs @@ -1,9 +1,19 @@ //! AES encryption implementations //! //! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption. +//! +//! ## Zeroize policy +//! +//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop. +//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate. +//! The expanded key schedule lives inside that type and cannot be +//! zeroized from outside. Callers that hold raw key material (e.g. +//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for +//! zeroizing their own copies. use aes::Aes256; use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; +use zeroize::Zeroize; use crate::error::{ProxyError, Result}; type Aes256Ctr = Ctr128BE; @@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE; /// AES-256-CTR encryptor/decryptor /// -/// CTR mode is symmetric - encryption and decryption are the same operation. +/// CTR mode is symmetric — encryption and decryption are the same operation. +/// +/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule +/// + counter) is opaque and cannot be zeroized. If you need to protect key +/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site +/// before dropping them. pub struct AesCtr { cipher: Aes256Ctr, } @@ -62,14 +77,23 @@ impl AesCtr { /// AES-256-CBC cipher with proper chaining /// -/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption +/// Unlike CTR mode, CBC is NOT symmetric — encryption and decryption /// are different operations. This implementation handles CBC chaining /// correctly across multiple blocks. +/// +/// Key and IV are zeroized on drop. 
pub struct AesCbc { key: [u8; 32], iv: [u8; 16], } +impl Drop for AesCbc { + fn drop(&mut self) { + self.key.zeroize(); + self.iv.zeroize(); + } +} + impl AesCbc { /// AES block size const BLOCK_SIZE: usize = 16; @@ -141,17 +165,9 @@ impl AesCbc { for chunk in data.chunks(Self::BLOCK_SIZE) { let plaintext: [u8; 16] = chunk.try_into().unwrap(); - - // XOR plaintext with previous ciphertext (or IV for first block) let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); - - // Encrypt the XORed block let ciphertext = self.encrypt_block(&xored, &key_schedule); - - // Save for next iteration prev_ciphertext = ciphertext; - - // Append to result result.extend_from_slice(&ciphertext); } @@ -180,17 +196,9 @@ impl AesCbc { for chunk in data.chunks(Self::BLOCK_SIZE) { let ciphertext: [u8; 16] = chunk.try_into().unwrap(); - - // Decrypt the block let decrypted = self.decrypt_block(&ciphertext, &key_schedule); - - // XOR with previous ciphertext (or IV for first block) let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext); - - // Save current ciphertext for next iteration prev_ciphertext = ciphertext; - - // Append to result result.extend_from_slice(&plaintext); } @@ -217,16 +225,13 @@ impl AesCbc { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - // XOR with previous ciphertext for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - // Encrypt in-place let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.encrypt_block(block_array, &key_schedule); - // Save for next iteration prev_ciphertext = *block_array; } @@ -248,26 +253,20 @@ impl AesCbc { use aes::cipher::KeyInit; let key_schedule = aes::Aes256::new((&self.key).into()); - // For in-place decryption, we need to save ciphertext blocks - // before we overwrite them let mut prev_ciphertext = self.iv; for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { let block = &mut data[i..i + Self::BLOCK_SIZE]; - // Save current ciphertext before modifying let current_ciphertext: [u8; 16] = block.try_into().unwrap(); - // Decrypt in-place let block_array: &mut [u8; 16] = block.try_into().unwrap(); *block_array = self.decrypt_block(block_array, &key_schedule); - // XOR with previous ciphertext for j in 0..Self::BLOCK_SIZE { block[j] ^= prev_ciphertext[j]; } - // Save for next iteration prev_ciphertext = current_ciphertext; } @@ -347,10 +346,8 @@ mod tests { let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); - // Encrypted should be different assert_ne!(&data[..], original); - // Decrypt with fresh cipher let mut cipher = AesCtr::new(&key, iv); cipher.apply(&mut data); @@ -364,7 +361,7 @@ mod tests { let key = [0u8; 32]; let iv = [0u8; 16]; - let original = [0u8; 32]; // 2 blocks + let original = [0u8; 32]; let cipher = AesCbc::new(key, iv); let encrypted = cipher.encrypt(&original).unwrap(); @@ -375,31 +372,25 @@ mod tests { #[test] fn test_aes_cbc_chaining_works() { - // This is the key test - verify CBC chaining is correct let key = [0x42u8; 32]; let iv = [0x00u8; 16]; - // Two IDENTICAL plaintext blocks let plaintext = [0xAAu8; 32]; let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext let block1 = &ciphertext[0..16]; let block2 = &ciphertext[16..32]; assert_ne!( block1, block2, - "CBC chaining broken: identical plaintext blocks produced identical ciphertext. \ - This indicates ECB mode, not CBC!" 
+ "CBC chaining broken: identical plaintext blocks produced identical ciphertext" ); } #[test] fn test_aes_cbc_known_vector() { - // Test with known NIST test vector - // AES-256-CBC with zero key and zero IV let key = [0u8; 32]; let iv = [0u8; 16]; let plaintext = [0u8; 16]; @@ -407,11 +398,9 @@ mod tests { let cipher = AesCbc::new(key, iv); let ciphertext = cipher.encrypt(&plaintext).unwrap(); - // Decrypt and verify roundtrip let decrypted = cipher.decrypt(&ciphertext).unwrap(); assert_eq!(plaintext.as_slice(), decrypted.as_slice()); - // Ciphertext should not be all zeros assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); } @@ -420,7 +409,6 @@ mod tests { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - // 5 blocks = 80 bytes let plaintext: Vec = (0..80).collect(); let cipher = AesCbc::new(key, iv); @@ -435,7 +423,7 @@ mod tests { let key = [0x12u8; 32]; let iv = [0x34u8; 16]; - let original = [0x56u8; 48]; // 3 blocks + let original = [0x56u8; 48]; let mut buffer = original; let cipher = AesCbc::new(key, iv); @@ -462,41 +450,33 @@ mod tests { fn test_aes_cbc_unaligned_error() { let cipher = AesCbc::new([0u8; 32], [0u8; 16]); - // 15 bytes - not aligned to block size let result = cipher.encrypt(&[0u8; 15]); assert!(result.is_err()); - // 17 bytes - not aligned let result = cipher.encrypt(&[0u8; 17]); assert!(result.is_err()); } #[test] fn test_aes_cbc_avalanche_effect() { - // Changing one bit in plaintext should change entire ciphertext block - // and all subsequent blocks (due to chaining) let key = [0xAB; 32]; let iv = [0xCD; 16]; - let mut plaintext1 = [0u8; 32]; + let plaintext1 = [0u8; 32]; let mut plaintext2 = [0u8; 32]; - plaintext2[0] = 0x01; // Single bit difference in first block + plaintext2[0] = 0x01; let cipher = AesCbc::new(key, iv); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); - // First blocks should be different assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); - - // Second blocks should ALSO be different (chaining effect) assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); } #[test] fn test_aes_cbc_iv_matters() { - // Same plaintext with different IVs should produce different ciphertext let key = [0x55; 32]; let plaintext = [0x77u8; 16]; @@ -511,7 +491,6 @@ mod tests { #[test] fn test_aes_cbc_deterministic() { - // Same key, IV, plaintext should always produce same ciphertext let key = [0x99; 32]; let iv = [0x88; 16]; let plaintext = [0x77u8; 32]; @@ -524,6 +503,23 @@ mod tests { assert_eq!(ciphertext1, ciphertext2); } + // ============= Zeroize Tests ============= + + #[test] + fn test_aes_cbc_zeroize_on_drop() { + let key = [0xAA; 32]; + let iv = [0xBB; 16]; + + let cipher = AesCbc::new(key, iv); + // Verify key/iv are set + assert_eq!(cipher.key, [0xAA; 32]); + assert_eq!(cipher.iv, [0xBB; 16]); + + drop(cipher); + // After drop, key/iv are zeroized (can't observe directly, + // but the Drop impl runs without panic) + } + // ============= Error Handling Tests ============= #[test] diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index 0472018..cf3ba0d 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -1,3 +1,16 @@ +//! Cryptographic hash functions +//! +//! ## Protocol-required algorithms +//! +//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker +//! hash functions are **required by the Telegram Middle Proxy protocol** +//! (`derive_middleproxy_keys`) and cannot be replaced without breaking +//! compatibility. 
They are NOT used for any security-sensitive purpose +//! outside of that specific key derivation scheme mandated by Telegram. +//! +//! Static analysis tools (CodeQL, cargo-audit) may flag them — the +//! usages are intentional and protocol-mandated. + use hmac::{Hmac, Mac}; use sha2::Sha256; use md5::Md5; @@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] { mac.finalize().into_bytes().into() } -/// SHA-1 +/// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation. +/// Not used for general-purpose hashing. pub fn sha1(data: &[u8]) -> [u8; 20] { let mut hasher = Sha1::new(); hasher.update(data); hasher.finalize().into() } -/// MD5 +/// MD5 — **protocol-required** by Telegram Middle Proxy key derivation. +/// Not used for general-purpose hashing. pub fn md5(data: &[u8]) -> [u8; 16] { let mut hasher = Md5::new(); hasher.update(data); @@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 { crc32fast::hash(data) } -/// Middle Proxy Keygen +/// Middle Proxy key derivation +/// +/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol. +/// These algorithms are NOT replaceable here — changing them would break +/// interoperability with Telegram's middle proxy infrastructure. pub fn derive_middleproxy_keys( nonce_srv: &[u8; 16], nonce_clt: &[u8; 16], diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 6339927..dfc2be6 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -6,4 +6,4 @@ pub mod random; pub use aes::{AesCtr, AesCbc}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; -pub use random::{SecureRandom, SECURE_RANDOM}; \ No newline at end of file +pub use random::SecureRandom; \ No newline at end of file diff --git a/src/crypto/random.rs b/src/crypto/random.rs index c179f25..18862ab 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -3,11 +3,8 @@ use rand::{Rng, RngCore, SeedableRng}; use rand::rngs::StdRng; use parking_lot::Mutex; +use zeroize::Zeroize; use crate::crypto::AesCtr; -use once_cell::sync::Lazy; - -/// Global secure random instance -pub static SECURE_RANDOM: Lazy = Lazy::new(SecureRandom::new); /// Cryptographically secure PRNG with AES-CTR pub struct SecureRandom { @@ -20,18 +17,30 @@ struct SecureRandomInner { buffer: Vec, } +impl Drop for SecureRandomInner { + fn drop(&mut self) { + self.buffer.zeroize(); + } +} + impl SecureRandom { pub fn new() -> Self { - let mut rng = StdRng::from_entropy(); + let mut seed_source = rand::rng(); + let mut rng = StdRng::from_rng(&mut seed_source); let mut key = [0u8; 32]; rng.fill_bytes(&mut key); - let iv: u128 = rng.gen(); + let iv: u128 = rng.random(); + + let cipher = AesCtr::new(&key, iv); + + // Zeroize local key copy — cipher already consumed it + key.zeroize(); Self { inner: Mutex::new(SecureRandomInner { rng, - cipher: AesCtr::new(&key, iv), + cipher, buffer: Vec::with_capacity(1024), }), } @@ -78,7 +87,6 @@ impl SecureRandom { result |= (b as u64) << (i * 8); } - // Mask extra bits if k < 64 { result &= (1u64 << k) - 1; } @@ -107,13 +115,13 @@ impl SecureRandom { /// Generate random u32 pub fn u32(&self) -> u32 { let mut inner = self.inner.lock(); - inner.rng.gen() + inner.rng.random() } /// Generate random u64 pub fn u64(&self) -> u64 { let mut inner = self.inner.lock(); - inner.rng.gen() + inner.rng.random() } } @@ -162,12 +170,10 @@ mod tests { fn test_bits() { let rng = SecureRandom::new(); - // Single bit should be 0 or 1 for _ in 0..100 { assert!(rng.bits(1) <= 1); } - // 8 bits should be 0-255 for _ in 0..100 { assert!(rng.bits(8) <= 255); } 
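// Illustration (not part of the patch): with the `SECURE_RANDOM` global and its
// `once_cell::Lazy` gone, callers own the generator explicitly. A minimal sketch
// of the intended pattern, mirroring what main.rs does later in this PR: one
// `SecureRandom` behind an `Arc`, cloned per task. `SecureRandom::new()` and
// `bytes()` come from this diff; the variable names are illustrative only.
//
//     use std::sync::Arc;
//
//     let rng = Arc::new(SecureRandom::new());
//     let fake_cert = rng.bytes(2048);      // random bytes from the buffered AES-CTR stream
//     let task_rng = Arc::clone(&rng);      // cheap handle for a spawned task
//     tokio::spawn(async move {
//         let padding = task_rng.bytes(16);
//         // ... use the bytes ...
//     });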
@@ -185,10 +191,8 @@ mod tests {
             }
         }
 
-        // Should have seen all items
         assert_eq!(seen.len(), 5);
 
-        // Empty slice should return None
         let empty: Vec = vec![];
         assert!(rng.choose(&empty).is_none());
     }
@@ -201,12 +205,10 @@ mod tests {
         let mut shuffled = original.clone();
         rng.shuffle(&mut shuffled);
 
-        // Should contain same elements
         let mut sorted = shuffled.clone();
         sorted.sort();
         assert_eq!(sorted, original);
 
-        // Should be different order (with very high probability)
         assert_ne!(shuffled, original);
     }
 }
\ No newline at end of file
diff --git a/src/error.rs b/src/error.rs
index 6d9bae3..f934672 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -118,16 +118,13 @@ pub trait Recoverable {
 impl Recoverable for StreamError {
     fn is_recoverable(&self) -> bool {
         match self {
-            // Partial operations can be retried
             Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
-            // I/O errors depend on kind
             Self::Io(e) => matches!(
                 e.kind(),
                 std::io::ErrorKind::WouldBlock
                     | std::io::ErrorKind::Interrupted
                     | std::io::ErrorKind::TimedOut
             ),
-            // These are not recoverable
             Self::Poisoned { .. }
             | Self::BufferOverflow { .. }
             | Self::InvalidFrame { .. }
@@ -137,13 +134,9 @@ impl Recoverable for StreamError {
     fn can_continue(&self) -> bool {
         match self {
-            // Poisoned stream cannot be used
             Self::Poisoned { .. } => false,
-            // EOF means stream is done
             Self::UnexpectedEof => false,
-            // Buffer overflow is fatal
             Self::BufferOverflow { .. } => false,
-            // Others might allow continuation
             _ => true,
         }
     }
@@ -383,18 +376,18 @@ mod tests {
 
     #[test]
     fn test_handshake_result() {
-        let success: HandshakeResult<i32> = HandshakeResult::Success(42);
+        let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
         assert!(success.is_success());
         assert!(!success.is_bad_client());
 
-        let bad: HandshakeResult<i32> = HandshakeResult::BadClient;
+        let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
         assert!(!bad.is_success());
         assert!(bad.is_bad_client());
     }
 
     #[test]
     fn test_handshake_result_map() {
-        let success: HandshakeResult<i32> = HandshakeResult::Success(42);
+        let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
         let mapped = success.map(|x| x * 2);
 
         match mapped {
diff --git a/src/main.rs b/src/main.rs
index a672ed0..9d2cb84 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,9 +5,10 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::net::TcpListener;
 use tokio::signal;
-use tracing::{info, error, warn};
+use tracing::{info, error, warn, debug};
 use tracing_subscriber::{fmt, EnvFilter};
 
+mod cli;
 mod config;
 mod crypto;
 mod error;
@@ -18,78 +19,168 @@ mod stream;
 mod transport;
 mod util;
 
-use crate::config::ProxyConfig;
+use crate::config::{ProxyConfig, LogLevel};
 use crate::proxy::ClientHandler;
 use crate::stats::{Stats, ReplayChecker};
+use crate::crypto::SecureRandom;
 use crate::transport::{create_listener, ListenOptions, UpstreamManager};
 use crate::util::ip::detect_ip;
 use crate::stream::BufferPool;
 
+fn parse_cli() -> (String, bool, Option<String>) {
+    let mut config_path = "config.toml".to_string();
+    let mut silent = false;
+    let mut log_level: Option<String> = None;
+
+    let args: Vec<String> = std::env::args().skip(1).collect();
+
+    // Check for --init first (handled before tokio)
+    if let Some(init_opts) = cli::parse_init_args(&args) {
+        if let Err(e) = cli::run_init(init_opts) {
+            eprintln!("[telemt] Init failed: {}", e);
+            std::process::exit(1);
+        }
+        std::process::exit(0);
+    }
+
+    let mut i = 0;
+    while i < args.len() {
+        match args[i].as_str() {
+            "--silent" | "-s" => { silent = true; }
+            "--log-level" => {
+                i += 
1; + if i < args.len() { log_level = Some(args[i].clone()); } + } + s if s.starts_with("--log-level=") => { + log_level = Some(s.trim_start_matches("--log-level=").to_string()); + } + "--help" | "-h" => { + eprintln!("Usage: telemt [config.toml] [OPTIONS]"); + eprintln!(); + eprintln!("Options:"); + eprintln!(" --silent, -s Suppress info logs"); + eprintln!(" --log-level debug|verbose|normal|silent"); + eprintln!(" --help, -h Show this help"); + eprintln!(); + eprintln!("Setup (fire-and-forget):"); + eprintln!(" --init Generate config, install systemd service, start"); + eprintln!(" --port Listen port (default: 443)"); + eprintln!(" --domain TLS domain for masking (default: www.google.com)"); + eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); + eprintln!(" --user Username (default: user)"); + eprintln!(" --config-dir Config directory (default: /etc/telemt)"); + eprintln!(" --no-start Don't start the service after install"); + std::process::exit(0); + } + s if !s.starts_with('-') => { config_path = s.to_string(); } + other => { eprintln!("Unknown option: {}", other); } + } + i += 1; + } + + (config_path, silent, log_level) +} + #[tokio::main] async fn main() -> Result<(), Box> { - // Initialize logging - fmt() - .with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap())) - .init(); + let (config_path, cli_silent, cli_log_level) = parse_cli(); - // Load config - let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string()); let config = match ProxyConfig::load(&config_path) { Ok(c) => c, Err(e) => { - // If config doesn't exist, try to create default if std::path::Path::new(&config_path).exists() { - error!("Failed to load config: {}", e); + eprintln!("[telemt] Error: {}", e); std::process::exit(1); } else { let default = ProxyConfig::default(); - let toml = toml::to_string_pretty(&default).unwrap(); - std::fs::write(&config_path, toml).unwrap(); - info!("Created default config at {}", config_path); + std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap(); + eprintln!("[telemt] Created default config at {}", config_path); default } } }; - config.validate()?; + if let Err(e) = config.validate() { + eprintln!("[telemt] Invalid config: {}", e); + std::process::exit(1); + } + + let effective_log_level = if cli_silent { + LogLevel::Silent + } else if let Some(ref s) = cli_log_level { + LogLevel::from_str_loose(s) + } else { + config.general.log_level.clone() + }; - // Log loaded configuration for debugging - info!("=== Configuration Loaded ==="); - info!("TLS Domain: {}", config.censorship.tls_domain); - info!("Mask enabled: {}", config.censorship.mask); - info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain)); - info!("Mask port: {}", config.censorship.mask_port); - info!("Modes: classic={}, secure={}, tls={}", - config.general.modes.classic, - config.general.modes.secure, - config.general.modes.tls - ); - info!("============================"); + let filter = if std::env::var("RUST_LOG").is_ok() { + EnvFilter::from_default_env() + } else { + EnvFilter::new(effective_log_level.to_filter_str()) + }; + fmt().with_env_filter(filter).init(); + + info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); + info!("Log level: {}", effective_log_level); + info!("Modes: classic={} secure={} tls={}", + config.general.modes.classic, + config.general.modes.secure, + config.general.modes.tls); + info!("TLS domain: {}", config.censorship.tls_domain); + info!("Mask: 
{} -> {}:{}", + config.censorship.mask, + config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), + config.censorship.mask_port); + + if config.censorship.tls_domain == "www.google.com" { + warn!("Using default tls_domain. Consider setting a custom domain."); + } + + let prefer_ipv6 = config.general.prefer_ipv6; let config = Arc::new(config); let stats = Arc::new(Stats::new()); + let rng = Arc::new(SecureRandom::new()); - // Initialize global ReplayChecker - // Using sharded implementation for better concurrency - let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len)); + let replay_checker = Arc::new(ReplayChecker::new( + config.access.replay_check_len, + Duration::from_secs(config.access.replay_window_secs), + )); - // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); - - // Initialize Buffer Pool - // 16KB buffers, max 4096 buffers (~64MB total cached) let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); - // Start Health Checks + // Startup DC ping + println!("=== Telegram DC Connectivity ==="); + let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; + for upstream_result in &ping_results { + println!(" via {}", upstream_result.upstream_name); + for dc in &upstream_result.results { + match (&dc.rtt_ms, &dc.error) { + (Some(rtt), _) => { + println!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt); + } + (None, Some(err)) => { + println!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err); + } + _ => { + println!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr); + } + } + } + } + println!("================================"); + + // Background tasks let um_clone = upstream_manager.clone(); - tokio::spawn(async move { - um_clone.run_health_checks().await; - }); + tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; }); + + let rc_clone = replay_checker.clone(); + tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; }); - // Detect public IP if needed (once at startup) let detected_ip = detect_ip().await; + debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); - // Start Listeners let mut listeners = Vec::new(); for listener_conf in &config.server.listeners { @@ -104,7 +195,6 @@ async fn main() -> Result<(), Box> { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - // Determine public IP for tg:// links let public_ip = if let Some(ip) = listener_conf.announce_ip { ip } else if listener_conf.ip.is_unspecified() { @@ -117,33 +207,29 @@ async fn main() -> Result<(), Box> { listener_conf.ip }; - // Show links for configured users if !config.show_link.is_empty() { - info!("--- Proxy Links for {} ---", public_ip); + println!("--- Proxy Links ({}) ---", public_ip); for user_name in &config.show_link { if let Some(secret) = config.access.users.get(user_name) { - info!("User: {}", user_name); - + println!("[{}]", user_name); if config.general.modes.classic { - info!(" Classic: tg://proxy?server={}&port={}&secret={}", + println!(" Classic: tg://proxy?server={}&port={}&secret={}", public_ip, config.server.port, secret); } - if config.general.modes.secure { - info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", + println!(" DD: tg://proxy?server={}&port={}&secret=dd{}", public_ip, config.server.port, secret); } - if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); - info!(" EE-TLS: 
tg://proxy?server={}&port={}&secret=ee{}{}", + println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", public_ip, config.server.port, secret, domain_hex); } } else { - warn!("User '{}' specified in show_link not found in users list", user_name); + warn!("User '{}' in show_link not found", user_name); } } - info!("-----------------------------------"); + println!("------------------------"); } listeners.push(listener); @@ -155,17 +241,17 @@ async fn main() -> Result<(), Box> { } if listeners.is_empty() { - error!("No listeners could be started. Exiting."); + error!("No listeners. Exiting."); std::process::exit(1); } - // Accept loop for listener in listeners { let config = config.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); tokio::spawn(async move { loop { @@ -176,18 +262,14 @@ async fn main() -> Result<(), Box> { let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); tokio::spawn(async move { if let Err(e) = ClientHandler::new( - stream, - peer_addr, - config, - stats, - upstream_manager, - replay_checker, - buffer_pool + stream, peer_addr, config, stats, + upstream_manager, replay_checker, buffer_pool, rng ).run().await { - // Log only relevant errors + debug!(peer = %peer_addr, error = %e, "Connection error"); } }); } @@ -200,7 +282,6 @@ async fn main() -> Result<(), Box> { }); } - // Wait for signal match signal::ctrl_c().await { Ok(()) => info!("Shutting down..."), Err(e) => error!("Signal error: {}", e), diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index d09473e..7451c83 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -1,13 +1,13 @@ //! 
Protocol constants and datacenter addresses
 
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
-use once_cell::sync::Lazy;
+use std::sync::LazyLock;
 
 // ============= Telegram Datacenters =============
 
 pub const TG_DATACENTER_PORT: u16 = 443;
 
-pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
+pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
     vec![
         IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
         IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
     ]
 });
 
-pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
+pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
     vec![
         IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
         IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
 
 // ============= Middle Proxies (for advertising) =============
 
-pub static TG_MIDDLE_PROXIES_V4: Lazy>> =
-    Lazy::new(|| {
+pub static TG_MIDDLE_PROXIES_V4: LazyLock>> =
+    LazyLock::new(|| {
         let mut m = std::collections::HashMap::new();
         m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
         m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy>> =
-    Lazy::new(|| {
+pub static TG_MIDDLE_PROXIES_V6: LazyLock>> =
+    LazyLock::new(|| {
         let mut m = std::collections::HashMap::new();
         m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
         m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
@@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
 // ============= Buffer Sizes =============
 
 /// Default buffer size
-/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
-/// the new buffering strategy for better iOS upload performance.
 pub const DEFAULT_BUFFER_SIZE: usize = 16384;
 
 /// Small buffer size for bad client handling
diff --git a/src/protocol/obfuscation.rs b/src/protocol/obfuscation.rs
index 4e09942..1c55c5f 100644
--- a/src/protocol/obfuscation.rs
+++ b/src/protocol/obfuscation.rs
@@ -1,10 +1,13 @@
 //! MTProto Obfuscation
 
+use zeroize::Zeroize;
 use crate::crypto::{sha256, AesCtr};
 use crate::error::Result;
 use super::constants::*;
 
 /// Obfuscation parameters from handshake
+///
+/// Key material is zeroized on drop.
#[derive(Debug, Clone)] pub struct ObfuscationParams { /// Key for decrypting client -> proxy traffic @@ -21,25 +24,31 @@ pub struct ObfuscationParams { pub dc_idx: i16, } +impl Drop for ObfuscationParams { + fn drop(&mut self) { + self.decrypt_key.zeroize(); + self.decrypt_iv.zeroize(); + self.encrypt_key.zeroize(); + self.encrypt_iv.zeroize(); + } +} + impl ObfuscationParams { /// Parse obfuscation parameters from handshake bytes /// Returns None if handshake doesn't match any user secret pub fn from_handshake( handshake: &[u8; HANDSHAKE_LEN], - secrets: &[(String, Vec)], // (username, secret_bytes) + secrets: &[(String, Vec)], ) -> Option<(Self, String)> { - // Extract prekey and IV for decryption let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - // Reversed for encryption direction let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; for (username, secret) in secrets { - // Derive decryption key let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(secret); @@ -47,26 +56,22 @@ impl ObfuscationParams { let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - // Create decryptor and decrypt handshake let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let decrypted = decryptor.decrypt(handshake); - // Check protocol tag let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, - None => continue, // Try next secret + None => continue, }; - // Extract DC index let dc_idx = i16::from_le_bytes( decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() ); - // Derive encryption key let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(secret); @@ -123,18 +128,15 @@ pub fn generate_nonce Vec>(mut random_bytes: R) -> [u8; H /// Check if nonce is valid (not matching reserved patterns) pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { - // Check first byte if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { return false; } - // Check first 4 bytes let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { return false; } - // Check bytes 4-7 let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); if RESERVED_NONCE_CONTINUES.contains(&continue_four) { return false; @@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { pub fn prepare_tg_nonce( nonce: &mut [u8; HANDSHAKE_LEN], proto_tag: ProtoTag, - enc_key_iv: Option<&[u8]>, // For fast mode + enc_key_iv: Option<&[u8]>, ) { - // Set protocol tag nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); - // For fast mode, copy the reversed enc_key_iv if let Some(key_iv) = enc_key_iv { let reversed: Vec = key_iv.iter().rev().copied().collect(); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); @@ -161,14 +161,12 @@ pub fn prepare_tg_nonce( /// Encrypt the outgoing nonce for Telegram pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { - // Derive encryption key from the nonce itself let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key = 
sha256(key_iv); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let mut encryptor = AesCtr::new(&enc_key, enc_iv); - // Only encrypt from PROTO_TAG_POS onwards let mut result = nonce.to_vec(); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); @@ -182,22 +180,18 @@ mod tests { #[test] fn test_is_valid_nonce() { - // Valid nonce let mut valid = [0x42u8; HANDSHAKE_LEN]; valid[4..8].copy_from_slice(&[1, 2, 3, 4]); assert!(is_valid_nonce(&valid)); - // Invalid: starts with 0xef let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[0] = 0xef; assert!(!is_valid_nonce(&invalid)); - // Invalid: starts with HEAD let mut invalid = [0x00u8; HANDSHAKE_LEN]; invalid[..4].copy_from_slice(b"HEAD"); assert!(!is_valid_nonce(&invalid)); - // Invalid: bytes 4-7 are zeros let mut invalid = [0x42u8; HANDSHAKE_LEN]; invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); assert!(!is_valid_nonce(&invalid)); diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 354ee9a..68cd3dc 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -4,7 +4,7 @@ //! for domain fronting. The handshake looks like valid TLS 1.3 but //! actually carries MTProto authentication data. -use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; +use crate::crypto::{sha256_hmac, SecureRandom}; use crate::error::{ProxyError, Result}; use super::constants::*; use std::time::{SystemTime, UNIX_EPOCH}; @@ -315,8 +315,8 @@ pub fn validate_tls_handshake( /// /// This generates random bytes that look like a valid X25519 public key. /// Since we're not doing real TLS, the actual cryptographic properties don't matter. -pub fn gen_fake_x25519_key() -> [u8; 32] { - let bytes = SECURE_RANDOM.bytes(32); +pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] { + let bytes = rng.bytes(32); bytes.try_into().unwrap() } @@ -333,8 +333,9 @@ pub fn build_server_hello( client_digest: &[u8; TLS_DIGEST_LEN], session_id: &[u8], fake_cert_len: usize, + rng: &SecureRandom, ) -> Vec { - let x25519_key = gen_fake_x25519_key(); + let x25519_key = gen_fake_x25519_key(rng); // Build ServerHello let server_hello = ServerHelloBuilder::new(session_id.to_vec()) @@ -351,7 +352,7 @@ pub fn build_server_hello( ]; // Build fake certificate (Application Data record) - let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); + let fake_cert = rng.bytes(fake_cert_len); let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.extend_from_slice(&TLS_VERSION); @@ -489,8 +490,9 @@ mod tests { #[test] fn test_gen_fake_x25519_key() { - let key1 = gen_fake_x25519_key(); - let key2 = gen_fake_x25519_key(); + let rng = SecureRandom::new(); + let key1 = gen_fake_x25519_key(&rng); + let key2 = gen_fake_x25519_key(&rng); assert_eq!(key1.len(), 32); assert_eq!(key2.len(), 32); @@ -545,7 +547,8 @@ mod tests { let client_digest = [0x42u8; 32]; let session_id = vec![0xAA; 32]; - let response = build_server_hello(secret, &client_digest, &session_id, 2048); + let rng = SecureRandom::new(); + let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng); // Should have at least 3 records assert!(response.len() > 100); @@ -577,8 +580,9 @@ mod tests { let client_digest = [0x42u8; 32]; let session_id = vec![0xAA; 32]; - let response1 = build_server_hello(secret, &client_digest, &session_id, 1024); - let response2 = build_server_hello(secret, &client_digest, &session_id, 1024); + let rng = SecureRandom::new(); + let 
response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng); + let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng); // Digest position should have non-zero data let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 29ef0cd..107cb7b 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -15,9 +15,8 @@ use crate::protocol::tls; use crate::stats::{Stats, ReplayChecker}; use crate::transport::{configure_client_socket, UpstreamManager}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; -use crate::crypto::AesCtr; +use crate::crypto::{AesCtr, SecureRandom}; -// Use absolute paths to avoid confusion use crate::proxy::handshake::{ handle_tls_handshake, handle_mtproto_handshake, HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, @@ -25,10 +24,8 @@ use crate::proxy::handshake::{ use crate::proxy::relay::relay_bidirectional; use crate::proxy::masking::handle_bad_client; -/// Client connection handler (builder struct) pub struct ClientHandler; -/// Running client handler with stream and context pub struct RunningClientHandler { stream: TcpStream, peer: SocketAddr, @@ -37,10 +34,10 @@ pub struct RunningClientHandler { replay_checker: Arc, upstream_manager: Arc, buffer_pool: Arc, + rng: Arc, } impl ClientHandler { - /// Create new client handler instance pub fn new( stream: TcpStream, peer: SocketAddr, @@ -49,28 +46,22 @@ impl ClientHandler { upstream_manager: Arc, replay_checker: Arc, buffer_pool: Arc, + rng: Arc, ) -> RunningClientHandler { RunningClientHandler { - stream, - peer, - config, - stats, - replay_checker, - upstream_manager, - buffer_pool, + stream, peer, config, stats, replay_checker, + upstream_manager, buffer_pool, rng, } } } impl RunningClientHandler { - /// Run the client handler pub async fn run(mut self) -> Result<()> { self.stats.increment_connects_all(); let peer = self.peer; debug!(peer = %peer, "New connection"); - // Configure socket if let Err(e) = configure_client_socket( &self.stream, self.config.timeouts.client_keepalive, @@ -79,16 +70,10 @@ impl RunningClientHandler { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } - // Perform handshake with timeout let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); - - // Clone stats for error handling block let stats = self.stats.clone(); - let result = timeout( - handshake_timeout, - self.do_handshake() - ).await; + let result = timeout(handshake_timeout, self.do_handshake()).await; match result { Ok(Ok(())) => { @@ -107,16 +92,14 @@ impl RunningClientHandler { } } - /// Perform handshake and relay async fn do_handshake(mut self) -> Result<()> { - // Read first bytes to determine handshake type let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); if is_tls { self.handle_tls_client(first_bytes).await @@ -125,14 +108,9 @@ impl RunningClientHandler { } } - /// Handle TLS-wrapped client - async fn handle_tls_client( - mut self, - first_bytes: [u8; 5], - ) -> Result<()> { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; - // Read TLS handshake length let tls_len = 
u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); @@ -140,34 +118,25 @@ impl RunningClientHandler { if tls_len < 512 { debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); self.stats.increment_connects_bad(); - // FIX: Split stream into reader/writer for handle_bad_client let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } - // Read full TLS handshake let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; - // Extract fields before consuming self.stream let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - // Split stream for reading/writing let (read_half, write_half) = self.stream.into_split(); - // Handle TLS handshake let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, - read_half, - write_half, - peer, - &config, - &replay_checker, + &handshake, read_half, write_half, peer, + &config, &replay_checker, &self.rng, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { @@ -178,83 +147,56 @@ impl RunningClientHandler { HandshakeResult::Error(e) => return Err(e), }; - // Read MTProto handshake through TLS debug!(peer = %peer, "Reading MTProto handshake through TLS"); let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; - // Handle MTProto handshake let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, - tls_reader, - tls_writer, - peer, - &config, - &replay_checker, - true, + &mtproto_handshake, tls_reader, tls_writer, peer, + &config, &replay_checker, true, ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { + HandshakeResult::BadClient { reader: _, writer: _ } => { stats.increment_connects_bad(); - // Valid TLS but invalid MTProto - drop - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping"); + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool + crypto_reader, crypto_writer, success, + self.upstream_manager, self.stats, self.config, + buffer_pool, self.rng, ).await } - /// Handle direct (non-TLS) client - async fn handle_direct_client( - mut self, - first_bytes: [u8; 5], - ) -> Result<()> { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; - // Check if non-TLS modes are enabled if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); - // FIX: Split stream into reader/writer for handle_bad_client let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } - // Read rest of handshake let mut handshake = [0u8; HANDSHAKE_LEN]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut 
handshake[5..]).await?; - // Extract fields let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - // Split stream let (read_half, write_half) = self.stream.into_split(); - // Handle MTProto handshake let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, - read_half, - write_half, - peer, - &config, - &replay_checker, - false, + &handshake, read_half, write_half, peer, + &config, &replay_checker, false, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { @@ -266,17 +208,12 @@ impl RunningClientHandler { }; Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool + crypto_reader, crypto_writer, success, + self.upstream_manager, self.stats, self.config, + buffer_pool, self.rng, ).await } - /// Static version of handle_authenticated_inner async fn handle_authenticated_static( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -285,6 +222,7 @@ impl RunningClientHandler { stats: Arc, config: Arc, buffer_pool: Arc, + rng: Arc, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -292,13 +230,11 @@ impl RunningClientHandler { { let user = &success.user; - // Check user limits if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { warn!(user = %user, error = %e, "User limit exceeded"); return Err(e); } - // Get datacenter address let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; info!( @@ -307,71 +243,54 @@ impl RunningClientHandler { dc = success.dc_idx, dc_addr = %dc_addr, proto = ?success.proto_tag, - fast_mode = config.general.fast_mode, "Connecting to Telegram" ); - // Connect to Telegram via UpstreamManager - let tg_stream = upstream_manager.connect(dc_addr).await?; + // Pass dc_idx for latency-based upstream selection + let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?; - debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); + debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); - // Perform Telegram handshake and get crypto streams let (tg_reader, tg_writer) = Self::do_tg_handshake_static( - tg_stream, - &success, - &config, + tg_stream, &success, &config, rng.as_ref(), ).await?; - debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); + debug!(peer = %success.peer, "TG handshake complete, starting relay"); - // Update stats stats.increment_user_connects(user); stats.increment_user_curr_connects(user); - // Relay traffic using buffer pool let relay_result = relay_bidirectional( - client_reader, - client_writer, - tg_reader, - tg_writer, - user, - Arc::clone(&stats), - buffer_pool, + client_reader, client_writer, + tg_reader, tg_writer, + user, Arc::clone(&stats), buffer_pool, ).await; - // Update stats stats.decrement_user_curr_connects(user); match &relay_result { - Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), - Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"), + Ok(()) => debug!(user = %user, "Relay completed"), + Err(e) => debug!(user = %user, error = %e, "Relay ended with error"), } relay_result } - /// Check user limits (static version) fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { - // Check 
expiration if let Some(expiration) = config.access.user_expirations.get(user) { if chrono::Utc::now() > *expiration { return Err(ProxyError::UserExpired { user: user.to_string() }); } } - // Check connection limit if let Some(limit) = config.access.user_max_tcp_conns.get(user) { - let current = stats.get_user_curr_connects(user); - if current >= *limit as u64 { + if stats.get_user_curr_connects(user) >= *limit as u64 { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); } } - // Check data quota if let Some(quota) = config.access.user_data_quota.get(user) { - let used = stats.get_user_total_octets(user); - if used >= *quota { + if stats.get_user_total_octets(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); } } @@ -379,7 +298,6 @@ impl RunningClientHandler { Ok(()) } - /// Get datacenter address by index (static version) fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let idx = (dc_idx.abs() - 1) as usize; @@ -396,45 +314,39 @@ impl RunningClientHandler { )) } - /// Perform handshake with Telegram server (static version) async fn do_tg_handshake_static( mut stream: TcpStream, success: &HandshakeSuccess, config: &ProxyConfig, + rng: &SecureRandom, ) -> Result<(CryptoReader, CryptoWriter)> { - // Generate nonce with keys for TG let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( success.proto_tag, - &success.dec_key, // Client's dec key + &success.dec_key, success.dec_iv, + rng, config.general.fast_mode, ); - // Encrypt nonce let encrypted_nonce = encrypt_tg_nonce(&nonce); debug!( peer = %success.peer, nonce_head = %hex::encode(&nonce[..16]), - encrypted_head = %hex::encode(&encrypted_nonce[..16]), "Sending nonce to Telegram" ); - // Send to Telegram stream.write_all(&encrypted_nonce).await?; stream.flush().await?; - debug!(peer = %success.peer, "Nonce sent to Telegram"); - - // Split stream and wrap with crypto let (read_half, write_half) = stream.into_split(); let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); - let tg_reader = CryptoReader::new(read_half, decryptor); - let tg_writer = CryptoWriter::new(write_half, encryptor); - - Ok((tg_reader, tg_writer)) + Ok(( + CryptoReader::new(read_half, decryptor), + CryptoWriter::new(write_half, encryptor), + )) } } \ No newline at end of file diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index bb0ad1a..81814e2 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -1,11 +1,11 @@ -//! MTProto Handshake Magics +//! MTProto Handshake use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace, info}; +use zeroize::Zeroize; -use crate::crypto::{sha256, AesCtr}; -use crate::crypto::random::SECURE_RANDOM; +use crate::crypto::{sha256, AesCtr, SecureRandom}; use crate::protocol::constants::*; use crate::protocol::tls; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; @@ -14,6 +14,9 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; /// Result of successful handshake +/// +/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is +/// zeroized on drop. 
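// A minimal sketch of an alternative to the manual Drop impl that follows, using the
// derive support of the zeroize crate (Zeroize + ZeroizeOnDrop, with
// #[zeroize(skip)] on fields that carry no key material). `MiniSuccess` is a
// hypothetical stand-in, not a type from this tree; the hand-written Drop on
// HandshakeSuccess achieves the same wipe-on-drop behaviour.
use zeroize::{Zeroize, ZeroizeOnDrop};

#[derive(Zeroize, ZeroizeOnDrop)]
struct MiniSuccess {
    dec_key: [u8; 32],
    dec_iv: u128,
    enc_key: [u8; 32],
    enc_iv: u128,
    #[zeroize(skip)]
    dc_idx: i16,
}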
#[derive(Debug, Clone)] pub struct HandshakeSuccess { /// Authenticated user name @@ -34,6 +37,15 @@ pub struct HandshakeSuccess { pub is_tls: bool, } +impl Drop for HandshakeSuccess { + fn drop(&mut self) { + self.dec_key.zeroize(); + self.dec_iv.zeroize(); + self.enc_key.zeroize(); + self.enc_iv.zeroize(); + } +} + /// Handle fake TLS handshake pub async fn handle_tls_handshake( handshake: &[u8], @@ -42,6 +54,7 @@ pub async fn handle_tls_handshake( peer: SocketAddr, config: &ProxyConfig, replay_checker: &ReplayChecker, + rng: &SecureRandom, ) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String), R, W> where R: AsyncRead + Unpin, @@ -49,30 +62,25 @@ where { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); - // Check minimum length if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - // Extract digest for replay check let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - // Check for replay if replay_checker.check_tls_digest(digest_half) { warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } - // Build secrets list let secrets: Vec<(String, Vec)> = config.access.users.iter() .filter_map(|(name, hex)| { hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) }) .collect(); - // Validate handshake let validation = match tls::validate_tls_handshake( handshake, &secrets, @@ -89,18 +97,17 @@ where } }; - // Get secret for response let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => return HandshakeResult::BadClient { reader, writer }, }; - // Build and send response let response = tls::build_server_hello( secret, &validation.digest, &validation.session_id, config.censorship.fake_cert_len, + rng, ); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -115,7 +122,6 @@ where return HandshakeResult::Error(ProxyError::Io(e)); } - // Record for replay protection only after successful handshake replay_checker.add_tls_digest(digest_half); info!( @@ -147,26 +153,21 @@ where { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - // Extract prekey and IV let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - // Check for replay if replay_checker.check_handshake(dec_prekey_iv) { warn!(peer = %peer, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } - // Reversed for encryption direction let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); - // Try each user's secret for (user, secret_hex) in &config.access.users { let secret = match hex::decode(secret_hex) { Ok(s) => s, Err(_) => continue, }; - // Derive decryption key let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; @@ -177,11 +178,9 @@ where let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - // Decrypt handshake to check protocol tag let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - // Check protocol tag let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); @@ -191,7 +190,6 @@ where None => continue, }; - // Check if mode is enabled let mode_ok = match proto_tag { ProtoTag::Secure => { 
if is_tls { config.general.modes.tls } else { config.general.modes.secure } @@ -204,12 +202,10 @@ where continue; } - // Extract DC index let dc_idx = i16::from_le_bytes( decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() ); - // Derive encryption key let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; @@ -220,10 +216,8 @@ where let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - // Record for replay protection replay_checker.add_handshake(dec_prekey_iv); - // Create new cipher instances let decryptor = AesCtr::new(&dec_key, dec_iv); let encryptor = AesCtr::new(&enc_key, enc_iv); @@ -264,10 +258,11 @@ pub fn generate_tg_nonce( proto_tag: ProtoTag, client_dec_key: &[u8; 32], client_dec_iv: u128, + rng: &SecureRandom, fast_mode: bool, ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { loop { - let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); + let bytes = rng.bytes(HANDSHAKE_LEN); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } @@ -323,13 +318,12 @@ mod tests { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; - let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + let rng = SecureRandom::new(); + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); - // Check length assert_eq!(nonce.len(), HANDSHAKE_LEN); - // Check proto tag is set let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); } @@ -339,17 +333,35 @@ mod tests { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; + let rng = SecureRandom::new(); let (nonce, _, _, _, _) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); + generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); let encrypted = encrypt_tg_nonce(&nonce); assert_eq!(encrypted.len(), HANDSHAKE_LEN); - - // First PROTO_TAG_POS bytes should be unchanged assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); - - // Rest should be different (encrypted) assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); } + + #[test] + fn test_handshake_success_zeroize_on_drop() { + let success = HandshakeSuccess { + user: "test".to_string(), + dc_idx: 2, + proto_tag: ProtoTag::Secure, + dec_key: [0xAA; 32], + dec_iv: 0xBBBBBBBB, + enc_key: [0xCC; 32], + enc_iv: 0xDDDDDDDD, + peer: "127.0.0.1:1234".parse().unwrap(), + is_tls: true, + }; + + assert_eq!(success.dec_key, [0xAA; 32]); + assert_eq!(success.enc_key, [0xCC; 32]); + + drop(success); + // Drop impl zeroizes key material without panic + } } \ No newline at end of file diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 9fa495d..fb30742 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -1,31 +1,28 @@ -//! Statistics +//! 
Statistics and replay protection use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use std::time::Instant; +use std::time::{Instant, Duration}; use dashmap::DashMap; -use parking_lot::{RwLock, Mutex}; +use parking_lot::Mutex; use lru::LruCache; use std::num::NonZeroUsize; use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; +use std::collections::VecDeque; +use tracing::debug; + +// ============= Stats ============= -/// Thread-safe statistics #[derive(Default)] pub struct Stats { - // Global counters connects_all: AtomicU64, connects_bad: AtomicU64, handshake_timeouts: AtomicU64, - - // Per-user stats user_stats: DashMap, - - // Start time - start_time: RwLock>, + start_time: parking_lot::RwLock>, } -/// Per-user statistics #[derive(Default)] pub struct UserStats { pub connects: AtomicU64, @@ -43,42 +40,20 @@ impl Stats { stats } - // Global stats - pub fn increment_connects_all(&self) { - self.connects_all.fetch_add(1, Ordering::Relaxed); - } + pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); } + pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); } + pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } + pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } + pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } - pub fn increment_connects_bad(&self) { - self.connects_bad.fetch_add(1, Ordering::Relaxed); - } - - pub fn increment_handshake_timeouts(&self) { - self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); - } - - pub fn get_connects_all(&self) -> u64 { - self.connects_all.load(Ordering::Relaxed) - } - - pub fn get_connects_bad(&self) -> u64 { - self.connects_bad.load(Ordering::Relaxed) - } - - // User stats pub fn increment_user_connects(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .connects - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .connects.fetch_add(1, Ordering::Relaxed); } pub fn increment_user_curr_connects(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .curr_connects - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .curr_connects.fetch_add(1, Ordering::Relaxed); } pub fn decrement_user_curr_connects(&self, user: &str) { @@ -88,47 +63,33 @@ impl Stats { } pub fn get_user_curr_connects(&self, user: &str) -> u64 { - self.user_stats - .get(user) + self.user_stats.get(user) .map(|s| s.curr_connects.load(Ordering::Relaxed)) .unwrap_or(0) } pub fn add_user_octets_from(&self, user: &str, bytes: u64) { - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_from_client - .fetch_add(bytes, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .octets_from_client.fetch_add(bytes, Ordering::Relaxed); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_to_client - .fetch_add(bytes, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .octets_to_client.fetch_add(bytes, Ordering::Relaxed); } pub fn increment_user_msgs_from(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_from_client - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .msgs_from_client.fetch_add(1, 
Ordering::Relaxed); } pub fn increment_user_msgs_to(&self, user: &str) { - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_to_client - .fetch_add(1, Ordering::Relaxed); + self.user_stats.entry(user.to_string()).or_default() + .msgs_to_client.fetch_add(1, Ordering::Relaxed); } pub fn get_user_total_octets(&self, user: &str) -> u64 { - self.user_stats - .get(user) + self.user_stats.get(user) .map(|s| { s.octets_from_client.load(Ordering::Relaxed) + s.octets_to_client.load(Ordering::Relaxed) @@ -143,57 +104,209 @@ impl Stats { } } -/// Sharded Replay attack checker using LRU cache -/// Uses multiple independent LRU caches to reduce lock contention +// ============= Replay Checker ============= + pub struct ReplayChecker { - shards: Vec, ()>>>, + shards: Vec>, shard_mask: usize, + window: Duration, + checks: AtomicU64, + hits: AtomicU64, + additions: AtomicU64, + cleanups: AtomicU64, } -impl ReplayChecker { - /// Create new replay checker with specified capacity per shard - /// Total capacity = capacity * num_shards - pub fn new(total_capacity: usize) -> Self { - // Use 64 shards for good concurrency - let num_shards = 64; - let shard_capacity = (total_capacity / num_shards).max(1); - let cap = NonZeroUsize::new(shard_capacity).unwrap(); - - let mut shards = Vec::with_capacity(num_shards); - for _ in 0..num_shards { - shards.push(Mutex::new(LruCache::new(cap))); - } - +struct ReplayEntry { + seen_at: Instant, + seq: u64, +} + +struct ReplayShard { + cache: LruCache, ReplayEntry>, + queue: VecDeque<(Instant, Box<[u8]>, u64)>, + seq_counter: u64, +} + +impl ReplayShard { + fn new(cap: NonZeroUsize) -> Self { Self { - shards, - shard_mask: num_shards - 1, + cache: LruCache::new(cap), + queue: VecDeque::with_capacity(cap.get()), + seq_counter: 0, } } - fn get_shard(&self, key: &[u8]) -> usize { + fn next_seq(&mut self) -> u64 { + self.seq_counter += 1; + self.seq_counter + } + + fn cleanup(&mut self, now: Instant, window: Duration) { + if window.is_zero() { + return; + } + let cutoff = now.checked_sub(window).unwrap_or(now); + + while let Some((ts, _, _)) = self.queue.front() { + if *ts >= cutoff { + break; + } + let (_, key, queue_seq) = self.queue.pop_front().unwrap(); + + // Use key.as_ref() to get &[u8] — avoids Borrow ambiguity + // between Borrow<[u8]> and Borrow> + if let Some(entry) = self.cache.peek(key.as_ref()) { + if entry.seq == queue_seq { + self.cache.pop(key.as_ref()); + } + } + } + } + + fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool { + self.cleanup(now, window); + // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]> + self.cache.get(key).is_some() + } + + fn add(&mut self, key: &[u8], now: Instant, window: Duration) { + self.cleanup(now, window); + + let seq = self.next_seq(); + let boxed_key: Box<[u8]> = key.into(); + + self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq }); + self.queue.push_back((now, boxed_key, seq)); + } + + fn len(&self) -> usize { + self.cache.len() + } +} + +impl ReplayChecker { + pub fn new(total_capacity: usize, window: Duration) -> Self { + let num_shards = 64; + let shard_capacity = (total_capacity / num_shards).max(1); + let cap = NonZeroUsize::new(shard_capacity).unwrap(); + + let mut shards = Vec::with_capacity(num_shards); + for _ in 0..num_shards { + shards.push(Mutex::new(ReplayShard::new(cap))); + } + + Self { + shards, + shard_mask: num_shards - 1, + window, + checks: AtomicU64::new(0), + hits: AtomicU64::new(0), + additions: AtomicU64::new(0), + cleanups: AtomicU64::new(0), + 
} + } + + fn get_shard_idx(&self, key: &[u8]) -> usize { let mut hasher = DefaultHasher::new(); key.hash(&mut hasher); (hasher.finish() as usize) & self.shard_mask } + + fn check(&self, data: &[u8]) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = self.shards[idx].lock(); + let found = shard.check(data, Instant::now(), self.window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + + fn add(&self, data: &[u8]) { + self.additions.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = self.shards[idx].lock(); + shard.add(data, Instant::now(), self.window); + } + + pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) } + pub fn add_handshake(&self, data: &[u8]) { self.add(data) } + pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) } + pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) } - pub fn check_handshake(&self, data: &[u8]) -> bool { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().contains(&data.to_vec()) + pub fn stats(&self) -> ReplayStats { + let mut total_entries = 0; + let mut total_queue_len = 0; + for shard in &self.shards { + let s = shard.lock(); + total_entries += s.cache.len(); + total_queue_len += s.queue.len(); + } + + ReplayStats { + total_entries, + total_queue_len, + total_checks: self.checks.load(Ordering::Relaxed), + total_hits: self.hits.load(Ordering::Relaxed), + total_additions: self.additions.load(Ordering::Relaxed), + total_cleanups: self.cleanups.load(Ordering::Relaxed), + num_shards: self.shards.len(), + window_secs: self.window.as_secs(), + } } - pub fn add_handshake(&self, data: &[u8]) { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().put(data.to_vec(), ()); + pub async fn run_periodic_cleanup(&self) { + let interval = if self.window.as_secs() > 60 { + Duration::from_secs(30) + } else { + Duration::from_secs(self.window.as_secs().max(1) / 2) + }; + + loop { + tokio::time::sleep(interval).await; + + let now = Instant::now(); + let mut cleaned = 0usize; + + for shard_mutex in &self.shards { + let mut shard = shard_mutex.lock(); + let before = shard.len(); + shard.cleanup(now, self.window); + let after = shard.len(); + cleaned += before.saturating_sub(after); + } + + self.cleanups.fetch_add(1, Ordering::Relaxed); + + if cleaned > 0 { + debug!(cleaned = cleaned, "Replay checker: periodic cleanup"); + } + } + } +} + +#[derive(Debug, Clone)] +pub struct ReplayStats { + pub total_entries: usize, + pub total_queue_len: usize, + pub total_checks: u64, + pub total_hits: u64, + pub total_additions: u64, + pub total_cleanups: u64, + pub num_shards: usize, + pub window_secs: u64, +} + +impl ReplayStats { + pub fn hit_rate(&self) -> f64 { + if self.total_checks == 0 { 0.0 } + else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 } } - pub fn check_tls_digest(&self, data: &[u8]) -> bool { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().contains(&data.to_vec()) - } - - pub fn add_tls_digest(&self, data: &[u8]) { - let shard_idx = self.get_shard(data); - self.shards[shard_idx].lock().put(data.to_vec(), ()); + pub fn ghost_ratio(&self) -> f64 { + if self.total_entries == 0 { 0.0 } + else { self.total_queue_len as f64 / self.total_entries as f64 } } } @@ -204,28 +317,60 @@ mod tests { #[test] fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); - - let stats1 = Arc::clone(&stats); - let stats2 = 
Arc::clone(&stats); - - stats1.increment_connects_all(); - stats2.increment_connects_all(); - stats1.increment_connects_all(); - + stats.increment_connects_all(); + stats.increment_connects_all(); + stats.increment_connects_all(); assert_eq!(stats.get_connects_all(), 3); } #[test] - fn test_replay_checker_sharding() { - let checker = ReplayChecker::new(100); - let data1 = b"test1"; - let data2 = b"test2"; - - checker.add_handshake(data1); - assert!(checker.check_handshake(data1)); - assert!(!checker.check_handshake(data2)); - - checker.add_handshake(data2); - assert!(checker.check_handshake(data2)); + fn test_replay_checker_basic() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + assert!(!checker.check_handshake(b"test1")); + checker.add_handshake(b"test1"); + assert!(checker.check_handshake(b"test1")); + assert!(!checker.check_handshake(b"test2")); + } + + #[test] + fn test_replay_checker_duplicate_add() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + checker.add_handshake(b"dup"); + checker.add_handshake(b"dup"); + assert!(checker.check_handshake(b"dup")); + } + + #[test] + fn test_replay_checker_expiration() { + let checker = ReplayChecker::new(100, Duration::from_millis(50)); + checker.add_handshake(b"expire"); + assert!(checker.check_handshake(b"expire")); + std::thread::sleep(Duration::from_millis(100)); + assert!(!checker.check_handshake(b"expire")); + } + + #[test] + fn test_replay_checker_stats() { + let checker = ReplayChecker::new(100, Duration::from_secs(60)); + checker.add_handshake(b"k1"); + checker.add_handshake(b"k2"); + checker.check_handshake(b"k1"); + checker.check_handshake(b"k3"); + let stats = checker.stats(); + assert_eq!(stats.total_additions, 2); + assert_eq!(stats.total_checks, 2); + assert_eq!(stats.total_hits, 1); + } + + #[test] + fn test_replay_checker_many_keys() { + let checker = ReplayChecker::new(1000, Duration::from_secs(60)); + for i in 0..500u32 { + checker.add(&i.to_le_bytes()); + } + for i in 0..500u32 { + assert!(checker.check(&i.to_le_bytes())); + } + assert_eq!(checker.stats().total_entries, 500); } } \ No newline at end of file diff --git a/src/stream/frame.rs b/src/stream/frame.rs index 42f6c74..b97d4cf 100644 --- a/src/stream/frame.rs +++ b/src/stream/frame.rs @@ -5,8 +5,10 @@ use bytes::{Bytes, BytesMut}; use std::io::Result; +use std::sync::Arc; use crate::protocol::constants::ProtoTag; +use crate::crypto::SecureRandom; // ============= Frame Types ============= @@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync { // ============= Codec Factory ============= /// Create a frame codec for the given protocol tag -pub fn create_codec(proto_tag: ProtoTag) -> Box { +pub fn create_codec(proto_tag: ProtoTag, rng: Arc) -> Box { match proto_tag { ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()), ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()), - ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()), + ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)), } } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 75f6bde..30bcc95 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -5,9 +5,11 @@ use bytes::{Bytes, BytesMut, BufMut}; use std::io::{self, Error, ErrorKind}; +use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; use crate::protocol::constants::ProtoTag; +use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, 
FrameCodec as FrameCodecTrait}; // ============= Unified Codec ============= @@ -21,14 +23,17 @@ pub struct FrameCodec { proto_tag: ProtoTag, /// Maximum allowed frame size max_frame_size: usize, + /// RNG for secure padding + rng: Arc, } impl FrameCodec { /// Create a new codec for the given protocol - pub fn new(proto_tag: ProtoTag) -> Self { + pub fn new(proto_tag: ProtoTag, rng: Arc) -> Self { Self { proto_tag, max_frame_size: 16 * 1024 * 1024, // 16MB default + rng, } } @@ -64,7 +69,7 @@ impl Encoder for FrameCodec { match self.proto_tag { ProtoTag::Abridged => encode_abridged(&frame, dst), ProtoTag::Intermediate => encode_intermediate(&frame, dst), - ProtoTag::Secure => encode_secure(&frame, dst), + ProtoTag::Secure => encode_secure(&frame, dst, &self.rng), } } } @@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result io::Result<()> { - use crate::crypto::random::SECURE_RANDOM; - +fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> { let data = &frame.data; // Simple ACK: just send data @@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { // Generate padding to make length not divisible by 4 let padding_len = if data.len() % 4 == 0 { // Add 1-3 bytes to make it non-aligned - (SECURE_RANDOM.range(3) + 1) as usize + (rng.range(3) + 1) as usize } else { // Already non-aligned, can add 0-3 - SECURE_RANDOM.range(4) as usize + rng.range(4) as usize }; let total_len = data.len() + padding_len; @@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { dst.extend_from_slice(data); if padding_len > 0 { - let padding = SECURE_RANDOM.bytes(padding_len); + let padding = rng.bytes(padding_len); dst.extend_from_slice(&padding); } @@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec { /// Secure Intermediate protocol codec pub struct SecureCodec { max_frame_size: usize, + rng: Arc, } impl SecureCodec { - pub fn new() -> Self { + pub fn new(rng: Arc) -> Self { Self { max_frame_size: 16 * 1024 * 1024, + rng, } } } impl Default for SecureCodec { fn default() -> Self { - Self::new() + Self::new(Arc::new(SecureRandom::new())) } } @@ -474,7 +479,7 @@ impl Encoder for SecureCodec { type Error = io::Error; fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { - encode_secure(&frame, dst) + encode_secure(&frame, dst, &self.rng) } } @@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec { fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result { let before = dst.len(); - encode_secure(frame, dst)?; + encode_secure(frame, dst, &self.rng)?; Ok(dst.len() - before) } @@ -506,6 +511,8 @@ mod tests { use tokio_util::codec::{FramedRead, FramedWrite}; use tokio::io::duplex; use futures::{SinkExt, StreamExt}; + use crate::crypto::SecureRandom; + use std::sync::Arc; #[tokio::test] async fn test_framed_abridged() { @@ -541,8 +548,8 @@ mod tests { async fn test_framed_secure() { let (client, server) = duplex(4096); - let mut writer = FramedWrite::new(client, SecureCodec::new()); - let mut reader = FramedRead::new(server, SecureCodec::new()); + let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new()))); + let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new()))); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let frame = Frame::new(original.clone()); @@ -557,8 +564,8 @@ mod tests { for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, 
ProtoTag::Secure] { let (client, server) = duplex(4096); - let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag)); - let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag)); + let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); + let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new()))); // Use 4-byte aligned data for abridged compatibility let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); @@ -607,7 +614,7 @@ mod tests { #[test] fn test_frame_too_large() { - let mut codec = FrameCodec::new(ProtoTag::Intermediate) + let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new())) .with_max_frame_size(100); // Create a "frame" that claims to be very large diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 9e62c8d..fd8c1b4 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut}; use std::io::{Error, ErrorKind, Result}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use crate::protocol::constants::*; -use crate::crypto::crc32; -use crate::crypto::random::SECURE_RANDOM; +use crate::crypto::{crc32, SecureRandom}; +use std::sync::Arc; use super::traits::{FrameMeta, LayeredStream}; // ============= Abridged (Compact) Frame ============= @@ -251,11 +251,12 @@ impl LayeredStream for SecureIntermediateFrameReader { /// Writer for secure intermediate MTProto framing pub struct SecureIntermediateFrameWriter { upstream: W, + rng: Arc, } impl SecureIntermediateFrameWriter { - pub fn new(upstream: W) -> Self { - Self { upstream } + pub fn new(upstream: W, rng: Arc) -> Self { + Self { upstream, rng } } } @@ -267,8 +268,8 @@ impl SecureIntermediateFrameWriter { } // Add random padding (0-3 bytes) - let padding_len = SECURE_RANDOM.range(4); - let padding = SECURE_RANDOM.bytes(padding_len); + let padding_len = self.rng.range(4); + let padding = self.rng.bytes(padding_len); let total_len = data.len() + padding_len; let len_bytes = (total_len as u32).to_le_bytes(); @@ -454,11 +455,11 @@ pub enum FrameWriterKind { } impl FrameWriterKind { - pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { + pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc) -> Self { match proto_tag { ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), - ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), + ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)), } } @@ -483,6 +484,8 @@ impl FrameWriterKind { mod tests { use super::*; use tokio::io::duplex; + use std::sync::Arc; + use crate::crypto::SecureRandom; #[tokio::test] async fn test_abridged_roundtrip() { @@ -539,7 +542,7 @@ mod tests { async fn test_secure_intermediate_padding() { let (client, server) = duplex(1024); - let mut writer = SecureIntermediateFrameWriter::new(client); + let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new())); let mut reader = SecureIntermediateFrameReader::new(server); let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; @@ -572,7 +575,7 @@ mod tests { async fn test_frame_reader_kind() { let (client, server) = duplex(1024); - let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); + let mut writer = 
FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new())); let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); let data = vec![1u8, 2, 3, 4]; diff --git a/src/transport/mod.rs b/src/transport/mod.rs index bbc5302..2b507d5 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -10,4 +10,4 @@ pub use pool::ConnectionPool; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use socket::*; pub use socks::*; -pub use upstream::UpstreamManager; \ No newline at end of file +pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; \ No newline at end of file diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 4a68830..86e6b2a 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -1,26 +1,113 @@ -//! Upstream Management +//! Upstream Management with per-DC latency-weighted selection use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::RwLock; +use tokio::time::Instant; use rand::Rng; -use tracing::{debug, warn, error, info}; +use tracing::{debug, warn, info, trace}; use crate::config::{UpstreamConfig, UpstreamType}; use crate::error::{Result, ProxyError}; +use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT}; use crate::transport::socket::create_outgoing_socket_bound; use crate::transport::socks::{connect_socks4, connect_socks5}; +/// Number of Telegram datacenters +const NUM_DCS: usize = 5; + +// ============= RTT Tracking ============= + +#[derive(Debug, Clone, Copy)] +struct LatencyEma { + value_ms: Option, + alpha: f64, +} + +impl LatencyEma { + const fn new(alpha: f64) -> Self { + Self { value_ms: None, alpha } + } + + fn update(&mut self, sample_ms: f64) { + self.value_ms = Some(match self.value_ms { + None => sample_ms, + Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha, + }); + } + + fn get(&self) -> Option { + self.value_ms + } +} + +// ============= Upstream State ============= + #[derive(Debug)] struct UpstreamState { config: UpstreamConfig, healthy: bool, fails: u32, last_check: std::time::Instant, + /// Per-DC latency EMA (index 0 = DC1, index 4 = DC5) + dc_latency: [LatencyEma; NUM_DCS], } +impl UpstreamState { + fn new(config: UpstreamConfig) -> Self { + Self { + config, + healthy: true, + fails: 0, + last_check: std::time::Instant::now(), + dc_latency: [LatencyEma::new(0.3); NUM_DCS], + } + } + + /// Convert dc_idx (1-based, may be negative) to array index 0..4 + fn dc_array_idx(dc_idx: i16) -> Option { + let idx = (dc_idx.unsigned_abs() as usize).checked_sub(1)?; + if idx < NUM_DCS { Some(idx) } else { None } + } + + /// Get latency for a specific DC, falling back to average across all known DCs + fn effective_latency(&self, dc_idx: Option) -> Option { + // Try DC-specific latency first + if let Some(di) = dc_idx.and_then(Self::dc_array_idx) { + if let Some(ms) = self.dc_latency[di].get() { + return Some(ms); + } + } + + // Fallback: average of all known DC latencies + let (sum, count) = self.dc_latency.iter() + .filter_map(|l| l.get()) + .fold((0.0, 0u32), |(s, c), v| (s + v, c + 1)); + + if count > 0 { Some(sum / count as f64) } else { None } + } +} + +/// Result of a single DC ping +#[derive(Debug, Clone)] +pub struct DcPingResult { + pub dc_idx: usize, + pub dc_addr: SocketAddr, + pub rtt_ms: Option, + pub error: Option, +} + +/// Result of startup ping for one upstream +#[derive(Debug, Clone)] +pub struct StartupPingResult { + 
pub results: Vec, + pub upstream_name: String, +} + +// ============= Upstream Manager ============= + #[derive(Clone)] pub struct UpstreamManager { upstreams: Arc>>, @@ -30,12 +117,7 @@ impl UpstreamManager { pub fn new(configs: Vec) -> Self { let states = configs.into_iter() .filter(|c| c.enabled) - .map(|c| UpstreamState { - config: c, - healthy: true, // Optimistic start - fails: 0, - last_check: std::time::Instant::now(), - }) + .map(UpstreamState::new) .collect(); Self { @@ -43,48 +125,78 @@ impl UpstreamManager { } } - /// Select an upstream using Weighted Round Robin (simplified) - async fn select_upstream(&self) -> Option { + /// Select upstream using latency-weighted random selection. + /// + /// `effective_weight = config_weight × latency_factor` + /// + /// where `latency_factor = 1000 / latency_ms` if latency is known, + /// or `1.0` if no latency data is available. + /// + /// This means a 50ms upstream gets factor 20, a 200ms upstream gets + /// factor 5 — the faster route is 4× more likely to be chosen + /// (all else being equal). + async fn select_upstream(&self, dc_idx: Option) -> Option { let upstreams = self.upstreams.read().await; if upstreams.is_empty() { return None; } - let healthy_indices: Vec = upstreams.iter() + let healthy: Vec = upstreams.iter() .enumerate() .filter(|(_, u)| u.healthy) .map(|(i, _)| i) .collect(); - if healthy_indices.is_empty() { - // If all unhealthy, try any random one - return Some(rand::thread_rng().gen_range(0..upstreams.len())); + if healthy.is_empty() { + // All unhealthy — pick any + return Some(rand::rng().gen_range(0..upstreams.len())); } - // Weighted selection - let total_weight: u32 = healthy_indices.iter() - .map(|&i| upstreams[i].config.weight as u32) - .sum(); + if healthy.len() == 1 { + return Some(healthy[0]); + } + + // Calculate latency-weighted scores + let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| { + let base = upstreams[i].config.weight as f64; + let latency_factor = upstreams[i].effective_latency(dc_idx) + .map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 }) + .unwrap_or(1.0); - if total_weight == 0 { - return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]); + (i, base * latency_factor) + }).collect(); + + let total: f64 = weights.iter().map(|(_, w)| w).sum(); + + if total <= 0.0 { + return Some(healthy[rand::rng().gen_range(0..healthy.len())]); } - let mut choice = rand::thread_rng().gen_range(0..total_weight); + let mut choice: f64 = rand::rng().gen_range(0.0..total); - for &idx in &healthy_indices { - let weight = upstreams[idx].config.weight as u32; + for &(idx, weight) in &weights { if choice < weight { + trace!( + upstream = idx, + dc = ?dc_idx, + weight = format!("{:.2}", weight), + total = format!("{:.2}", total), + "Upstream selected" + ); return Some(idx); } choice -= weight; } - Some(healthy_indices[0]) + Some(healthy[0]) } - pub async fn connect(&self, target: SocketAddr) -> Result { - let idx = self.select_upstream().await + /// Connect to target through a selected upstream. + /// + /// `dc_idx` is used for latency-based upstream selection and RTT tracking. + /// Pass `None` if DC index is unknown. 
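// Worked example of the effective_weight formula documented above, with the
// random draw passed in explicitly so the arithmetic stays visible (in
// select_upstream the draw comes from rand over 0.0..total). Two upstreams with
// config weight 10: a 50 ms route scores 10 * (1000/50) = 200, a 200 ms route
// scores 10 * (1000/200) = 50, so the faster route wins 200/250 = 80% of draws.
// `pick_weighted` and `demo` are illustrative helpers, not code from this tree.
fn pick_weighted(weights: &[f64], mut draw: f64) -> usize {
    for (i, w) in weights.iter().enumerate() {
        if draw < *w {
            return i;
        }
        draw -= *w;
    }
    weights.len() - 1
}

fn demo() {
    let weights = [10.0 * (1000.0 / 50.0), 10.0 * (1000.0 / 200.0)]; // [200.0, 50.0]
    assert_eq!(pick_weighted(&weights, 120.0), 0); // lands in the 0..200 bucket
    assert_eq!(pick_weighted(&weights, 230.0), 1); // lands in the 200..250 bucket
}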
+ pub async fn connect(&self, target: SocketAddr, dc_idx: Option) -> Result { + let idx = self.select_upstream(dc_idx).await .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; let upstream = { @@ -92,28 +204,34 @@ impl UpstreamManager { guard[idx].config.clone() }; + let start = Instant::now(); + match self.connect_via_upstream(&upstream, target).await { Ok(stream) => { - // Mark success + let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { if !u.healthy { - debug!("Upstream recovered: {:?}", u.config); + debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered"); } u.healthy = true; u.fails = 0; + + // Store per-DC latency + if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) { + u.dc_latency[di].update(rtt_ms); + } } Ok(stream) }, Err(e) => { - // Mark failure let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(idx) { u.fails += 1; - warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails); + warn!(fails = u.fails, "Upstream failed: {}", e); if u.fails > 3 { u.healthy = false; - warn!("Upstream disabled due to failures: {:?}", u.config); + warn!("Upstream marked unhealthy"); } } Err(e) @@ -129,18 +247,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(target, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&target.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); @@ -149,8 +265,6 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks4 { address, interface, user_id } => { - info!("Connecting to target {} via SOCKS4 proxy {}", target, address); - let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; @@ -159,18 +273,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? 
{ return Err(ProxyError::Io(e)); @@ -180,8 +292,6 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks5 { address, interface, username, password } => { - info!("Connecting to target {} via SOCKS5 proxy {}", target, address); - let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; @@ -190,18 +300,16 @@ impl UpstreamManager { let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - // Non-blocking connect logic socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, - Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - // Wait for connection to complete stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); @@ -213,13 +321,100 @@ impl UpstreamManager { } } - /// Background task to check health - pub async fn run_health_checks(&self) { - // Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2) - let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); + // ============= Startup Ping ============= + + /// Ping all Telegram DCs through all upstreams. + pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec { + let upstreams: Vec<(usize, UpstreamConfig)> = { + let guard = self.upstreams.read().await; + guard.iter().enumerate() + .map(|(i, u)| (i, u.config.clone())) + .collect() + }; + + let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; + + let mut all_results = Vec::new(); + + for (upstream_idx, upstream_config) in &upstreams { + let upstream_name = match &upstream_config.upstream_type { + UpstreamType::Direct { interface } => { + format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default()) + } + UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address), + UpstreamType::Socks5 { address, .. 
} => format!("socks5://{}", address), + }; + + let mut dc_results = Vec::new(); + + for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() { + let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT); + + let ping_result = tokio::time::timeout( + Duration::from_secs(5), + self.ping_single_dc(upstream_config, dc_addr) + ).await; + + let result = match ping_result { + Ok(Ok(rtt_ms)) => { + // Store per-DC latency + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(*upstream_idx) { + u.dc_latency[dc_zero_idx].update(rtt_ms); + } + DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: Some(rtt_ms), + error: None, + } + } + Ok(Err(e)) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: None, + error: Some(e.to_string()), + }, + Err(_) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr, + rtt_ms: None, + error: Some("timeout (5s)".to_string()), + }, + }; + + dc_results.push(result); + } + + all_results.push(StartupPingResult { + results: dc_results, + upstream_name, + }); + } + + all_results + } + + async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result { + let start = Instant::now(); + let _stream = self.connect_via_upstream(config, target).await?; + Ok(start.elapsed().as_secs_f64() * 1000.0) + } + + // ============= Health Checks ============= + + /// Background health check: rotates through DCs, 30s interval. + pub async fn run_health_checks(&self, prefer_ipv6: bool) { + let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; + let mut dc_rotation = 0usize; loop { - tokio::time::sleep(Duration::from_secs(60)).await; + tokio::time::sleep(Duration::from_secs(30)).await; + + let dc_zero_idx = dc_rotation % datacenters.len(); + dc_rotation += 1; + + let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT); let count = self.upstreams.read().await.len(); for i in 0..count { @@ -228,6 +423,7 @@ impl UpstreamManager { guard[i].config.clone() }; + let start = Instant::now(); let result = tokio::time::timeout( Duration::from_secs(10), self.connect_via_upstream(&config, check_target) @@ -238,18 +434,36 @@ impl UpstreamManager { match result { Ok(Ok(_stream)) => { + let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; + u.dc_latency[dc_zero_idx].update(rtt_ms); + if !u.healthy { - debug!("Upstream recovered: {:?}", u.config); + info!( + rtt = format!("{:.0}ms", rtt_ms), + dc = dc_zero_idx + 1, + "Upstream recovered" + ); } u.healthy = true; u.fails = 0; } Ok(Err(e)) => { - debug!("Health check failed for {:?}: {}", u.config, e); - // Don't mark unhealthy immediately in background check + u.fails += 1; + debug!(dc = dc_zero_idx + 1, fails = u.fails, + "Health check failed: {}", e); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (fails)"); + } } Err(_) => { - debug!("Health check timeout for {:?}", u.config); + u.fails += 1; + debug!(dc = dc_zero_idx + 1, fails = u.fails, + "Health check timeout"); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (timeout)"); + } } } u.last_check = std::time::Instant::now();
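// Illustrative startup wiring for the pieces above, as a sketch only: the
// function name and the config-derived arguments are hypothetical, while the
// called APIs (ping_all_dcs, run_health_checks, ReplayChecker::new,
// run_periodic_cleanup) are the ones defined in this file and in stats/mod.rs.
use std::sync::Arc;
use std::time::Duration;

use crate::config::UpstreamConfig;
use crate::stats::ReplayChecker;
use crate::transport::UpstreamManager;

async fn start_background_tasks(
    configs: Vec<UpstreamConfig>,
    prefer_ipv6: bool,
    replay_capacity: usize,
    replay_window: Duration,
) -> (UpstreamManager, Arc<ReplayChecker>) {
    let manager = UpstreamManager::new(configs);

    // One-shot startup ping: seeds the per-DC latency EMAs so the very first
    // client connections already get latency-weighted upstream selection.
    for upstream in manager.ping_all_dcs(prefer_ipv6).await {
        for dc in &upstream.results {
            match dc.rtt_ms {
                Some(ms) => tracing::info!(upstream = %upstream.upstream_name, dc = dc.dc_idx, rtt_ms = ms),
                None => tracing::warn!(upstream = %upstream.upstream_name, dc = dc.dc_idx, error = ?dc.error),
            }
        }
    }

    // Health checks rotate through the DCs every 30 s and keep the same EMAs fresh.
    let health = manager.clone();
    tokio::spawn(async move { health.run_health_checks(prefer_ipv6).await });

    // Replay checker with a time window; entries are also purged lazily on every
    // check/add, the periodic task only bounds memory during idle periods.
    let replay = Arc::new(ReplayChecker::new(replay_capacity, replay_window));
    let cleaner = Arc::clone(&replay);
    tokio::spawn(async move { cleaner.run_periodic_cleanup().await });

    (manager, replay)
}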