diff --git a/config.toml b/config.toml index 92e7249..bd0f90c 100644 --- a/config.toml +++ b/config.toml @@ -44,7 +44,7 @@ client_ack = 300 # === Anti-Censorship & Masking === [censorship] -tls_domain = "google.ru" +tls_domain = "petrovich.ru" mask = true mask_port = 443 # mask_host = "petrovich.ru" # Defaults to tls_domain if not set @@ -75,6 +75,6 @@ weight = 10 # [[upstreams]] # type = "socks5" -# address = "127.0.0.1:9050" +# address = "127.0.0.1:1080" # enabled = false # weight = 1 diff --git a/src/config/mod.rs b/src/config/mod.rs index 887973e..ec95036 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,32 +1,55 @@ //! Configuration +use crate::error::{ProxyError, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::path::Path; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use crate::error::{ProxyError, Result}; // ============= Helper Defaults ============= -fn default_true() -> bool { true } -fn default_port() -> u16 { 443 } -fn default_tls_domain() -> String { "www.google.com".to_string() } -fn default_mask_port() -> u16 { 443 } -fn default_replay_check_len() -> usize { 65536 } -fn default_replay_window_secs() -> u64 { 1800 } -fn default_handshake_timeout() -> u64 { 15 } -fn default_connect_timeout() -> u64 { 10 } -fn default_keepalive() -> u64 { 60 } -fn default_ack_timeout() -> u64 { 300 } -fn default_listen_addr() -> String { "0.0.0.0".to_string() } -fn default_fake_cert_len() -> usize { 2048 } -fn default_weight() -> u16 { 1 } +fn default_true() -> bool { + true +} +fn default_port() -> u16 { + 443 +} +fn default_tls_domain() -> String { + "www.google.com".to_string() +} +fn default_mask_port() -> u16 { + 443 +} +fn default_replay_check_len() -> usize { + 65536 +} +fn default_replay_window_secs() -> u64 { + 1800 +} +fn default_handshake_timeout() -> u64 { + 15 +} +fn default_connect_timeout() -> u64 { + 10 +} +fn default_keepalive() -> u64 { + 60 +} +fn default_ack_timeout() -> u64 { + 300 +} +fn default_listen_addr() -> String { + "0.0.0.0".to_string() +} +fn default_fake_cert_len() -> usize { + 2048 +} +fn default_weight() -> u16 { + 1 +} fn default_metrics_whitelist() -> Vec { - vec![ - "127.0.0.1".parse().unwrap(), - "::1".parse().unwrap(), - ] + vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()] } // ============= Log Level ============= @@ -96,7 +119,11 @@ pub struct ProxyModes { impl Default for ProxyModes { fn default() -> Self { - Self { classic: true, secure: true, tls: true } + Self { + classic: true, + secure: true, + tls: true, + } } } @@ -104,19 +131,37 @@ impl Default for ProxyModes { pub struct GeneralConfig { #[serde(default)] pub modes: ProxyModes, - + #[serde(default)] pub prefer_ipv6: bool, - + #[serde(default = "default_true")] pub fast_mode: bool, - + #[serde(default)] pub use_middle_proxy: bool, #[serde(default)] pub ad_tag: Option, - + + /// Path to proxy-secret binary file (auto-downloaded if absent). + /// Infrastructure secret from https://core.telegram.org/getProxySecret + #[serde(default)] + pub proxy_secret_path: Option, + + /// Public IP override for middle-proxy NAT environments. + /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". + #[serde(default)] + pub middle_proxy_nat_ip: Option, + + /// Enable STUN-based NAT probing to discover public IP:port for ME KDF. + #[serde(default)] + pub middle_proxy_nat_probe: bool, + + /// Optional STUN server address (host:port) for NAT probing. + #[serde(default)] + pub middle_proxy_nat_stun: Option, + #[serde(default)] pub log_level: LogLevel, } @@ -129,6 +174,10 @@ impl Default for GeneralConfig { fast_mode: true, use_middle_proxy: false, ad_tag: None, + proxy_secret_path: None, + middle_proxy_nat_ip: None, + middle_proxy_nat_probe: false, + middle_proxy_nat_stun: None, log_level: LogLevel::Normal, } } @@ -141,16 +190,16 @@ pub struct ServerConfig { #[serde(default = "default_listen_addr")] pub listen_addr_ipv4: String, - + #[serde(default)] pub listen_addr_ipv6: Option, - + #[serde(default)] pub listen_unix_sock: Option, - + #[serde(default)] pub metrics_port: Option, - + #[serde(default = "default_metrics_whitelist")] pub metrics_whitelist: Vec, @@ -176,13 +225,13 @@ impl Default for ServerConfig { pub struct TimeoutsConfig { #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, - + #[serde(default = "default_connect_timeout")] pub tg_connect: u64, - + #[serde(default = "default_keepalive")] pub client_keepalive: u64, - + #[serde(default = "default_ack_timeout")] pub client_ack: u64, } @@ -202,13 +251,13 @@ impl Default for TimeoutsConfig { pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] pub tls_domain: String, - + #[serde(default = "default_true")] pub mask: bool, - + #[serde(default)] pub mask_host: Option, - + #[serde(default = "default_mask_port")] pub mask_port: u16, @@ -239,19 +288,19 @@ pub struct AccessConfig { #[serde(default)] pub user_max_tcp_conns: HashMap, - + #[serde(default)] pub user_expirations: HashMap>, - + #[serde(default)] pub user_data_quota: HashMap, #[serde(default = "default_replay_check_len")] pub replay_check_len: usize, - + #[serde(default = "default_replay_window_secs")] pub replay_window_secs: u64, - + #[serde(default)] pub ignore_time_skew: bool, } @@ -259,7 +308,10 @@ pub struct AccessConfig { impl Default for AccessConfig { fn default() -> Self { let mut users = HashMap::new(); - users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); + users.insert( + "default".to_string(), + "00000000000000000000000000000000".to_string(), + ); Self { users, user_max_tcp_conns: HashMap::new(), @@ -454,12 +506,12 @@ pub struct ProxyConfig { impl ProxyConfig { pub fn load>(path: P) -> Result { - let content = std::fs::read_to_string(path) - .map_err(|e| ProxyError::Config(e.to_string()))?; - - let mut config: ProxyConfig = toml::from_str(&content) - .map_err(|e| ProxyError::Config(e.to_string()))?; - + let content = + std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + + let mut config: ProxyConfig = + toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; + // Validate secrets for (user, secret) in &config.access.users { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { @@ -469,33 +521,34 @@ impl ProxyConfig { }); } } - + // Validate tls_domain if config.censorship.tls_domain.is_empty() { return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); } - + // Validate mask_unix_sock if let Some(ref sock_path) = config.censorship.mask_unix_sock { if sock_path.is_empty() { return Err(ProxyError::Config( - "mask_unix_sock cannot be empty".to_string() + "mask_unix_sock cannot be empty".to_string(), )); } #[cfg(unix)] if sock_path.len() > 107 { - return Err(ProxyError::Config( - format!("mask_unix_sock path too long: {} bytes (max 107)", sock_path.len()) - )); + return Err(ProxyError::Config(format!( + "mask_unix_sock path too long: {} bytes (max 107)", + sock_path.len() + ))); } #[cfg(not(unix))] return Err(ProxyError::Config( - "mask_unix_sock is only supported on Unix platforms".to_string() + "mask_unix_sock is only supported on Unix platforms".to_string(), )); if config.censorship.mask_host.is_some() { return Err(ProxyError::Config( - "mask_unix_sock and mask_host are mutually exclusive".to_string() + "mask_unix_sock and mask_host are mutually exclusive".to_string(), )); } } @@ -504,11 +557,11 @@ impl ProxyConfig { if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() { config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } - + // Random fake_cert_len use rand::Rng; config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); - + // Migration: Populate listeners if empty if config.server.listeners.is_empty() { if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::() { @@ -518,7 +571,7 @@ impl ProxyConfig { }); } if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { - if let Ok(ipv6) = ipv6_str.parse::() { + if let Ok(ipv6) = ipv6_str.parse::() { config.server.listeners.push(ListenerConfig { ip: ipv6, announce_ip: None, @@ -529,31 +582,32 @@ impl ProxyConfig { // Migration: Populate upstreams if empty (Default Direct) if config.upstreams.is_empty() { - config.upstreams.push(UpstreamConfig { + config.upstreams.push(UpstreamConfig { upstream_type: UpstreamType::Direct { interface: None }, weight: 1, enabled: true, }); } - + Ok(config) } - + pub fn validate(&self) -> Result<()> { if self.access.users.is_empty() { return Err(ProxyError::Config("No users configured".to_string())); } - + if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls { return Err(ProxyError::Config("No modes enabled".to_string())); } - + if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { - return Err(ProxyError::Config( - format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) - )); + return Err(ProxyError::Config(format!( + "Invalid tls_domain: '{}'. Must be a valid domain name", + self.censorship.tls_domain + ))); } Ok(()) } -} \ No newline at end of file +} diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index cf3ba0d..6b3c654 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -55,6 +55,49 @@ pub fn crc32(data: &[u8]) -> u32 { crc32fast::hash(data) } +/// Build the exact prekey buffer used by Telegram Middle Proxy KDF. +/// +/// Returned buffer layout (IPv4): +/// nonce_srv | nonce_clt | clt_ts | srv_ip | clt_port | purpose | clt_ip | srv_port | secret | nonce_srv | [clt_v6 | srv_v6] | nonce_clt +pub fn build_middleproxy_prekey( + nonce_srv: &[u8; 16], + nonce_clt: &[u8; 16], + clt_ts: &[u8; 4], + srv_ip: Option<&[u8]>, + clt_port: &[u8; 2], + purpose: &[u8], + clt_ip: Option<&[u8]>, + srv_port: &[u8; 2], + secret: &[u8], + clt_ipv6: Option<&[u8; 16]>, + srv_ipv6: Option<&[u8; 16]>, +) -> Vec { + const EMPTY_IP: [u8; 4] = [0, 0, 0, 0]; + + let srv_ip = srv_ip.unwrap_or(&EMPTY_IP); + let clt_ip = clt_ip.unwrap_or(&EMPTY_IP); + + let mut s = Vec::with_capacity(256); + s.extend_from_slice(nonce_srv); + s.extend_from_slice(nonce_clt); + s.extend_from_slice(clt_ts); + s.extend_from_slice(srv_ip); + s.extend_from_slice(clt_port); + s.extend_from_slice(purpose); + s.extend_from_slice(clt_ip); + s.extend_from_slice(srv_port); + s.extend_from_slice(secret); + s.extend_from_slice(nonce_srv); + + if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) { + s.extend_from_slice(clt_v6); + s.extend_from_slice(srv_v6); + } + + s.extend_from_slice(nonce_clt); + s +} + /// Middle Proxy key derivation /// /// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol. @@ -73,30 +116,20 @@ pub fn derive_middleproxy_keys( clt_ipv6: Option<&[u8; 16]>, srv_ipv6: Option<&[u8; 16]>, ) -> ([u8; 32], [u8; 16]) { - const EMPTY_IP: [u8; 4] = [0, 0, 0, 0]; - - let srv_ip = srv_ip.unwrap_or(&EMPTY_IP); - let clt_ip = clt_ip.unwrap_or(&EMPTY_IP); - - let mut s = Vec::with_capacity(256); - s.extend_from_slice(nonce_srv); - s.extend_from_slice(nonce_clt); - s.extend_from_slice(clt_ts); - s.extend_from_slice(srv_ip); - s.extend_from_slice(clt_port); - s.extend_from_slice(purpose); - s.extend_from_slice(clt_ip); - s.extend_from_slice(srv_port); - s.extend_from_slice(secret); - s.extend_from_slice(nonce_srv); - - if let (Some(clt_v6), Some(srv_v6)) = (clt_ipv6, srv_ipv6) { - s.extend_from_slice(clt_v6); - s.extend_from_slice(srv_v6); - } - - s.extend_from_slice(nonce_clt); - + let s = build_middleproxy_prekey( + nonce_srv, + nonce_clt, + clt_ts, + srv_ip, + clt_port, + purpose, + clt_ip, + srv_port, + secret, + clt_ipv6, + srv_ipv6, + ); + let md5_1 = md5(&s[1..]); let sha1_sum = sha1(&s); let md5_2 = md5(&s[2..]); @@ -106,4 +139,40 @@ pub fn derive_middleproxy_keys( key[12..].copy_from_slice(&sha1_sum); (key, md5_2) -} \ No newline at end of file +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn middleproxy_prekey_sha_is_stable() { + let nonce_srv = [0x11u8; 16]; + let nonce_clt = [0x22u8; 16]; + let clt_ts = 0x44332211u32.to_le_bytes(); + let srv_ip = Some([149u8, 154, 175, 50].as_ref()); + let clt_ip = Some([10u8, 0, 0, 1].as_ref()); + let clt_port = 0x1f90u16.to_le_bytes(); // 8080 + let srv_port = 0x22b8u16.to_le_bytes(); // 8888 + let secret = vec![0x55u8; 128]; + + let prekey = build_middleproxy_prekey( + &nonce_srv, + &nonce_clt, + &clt_ts, + srv_ip, + &clt_port, + b"CLIENT", + clt_ip, + &srv_port, + &secret, + None, + None, + ); + let digest = sha256(&prekey); + assert_eq!( + hex::encode(digest), + "a4595b75f1f610f2575ace802ddc65c91b5acef3b0e0d18189e0c7c9f787d15c" + ); + } +} diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index dfc2be6..40951c6 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -5,5 +5,5 @@ pub mod hash; pub mod random; pub use aes::{AesCtr, AesCbc}; -pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; -pub use random::SecureRandom; \ No newline at end of file +pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey}; +pub use random::SecureRandom; diff --git a/src/main.rs b/src/main.rs index 280ce25..44e0815 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -//! Telemt - MTProxy on Rust +//! telemt — Telegram MTProto Proxy use std::net::SocketAddr; use std::sync::Arc; @@ -6,8 +6,8 @@ use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; use tokio::sync::Semaphore; -use tracing::{info, error, warn, debug}; -use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*}; +use tracing::{debug, error, info, warn}; +use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload}; mod cli; mod config; @@ -20,21 +20,22 @@ mod stream; mod transport; mod util; -use crate::config::{ProxyConfig, LogLevel}; -use crate::proxy::ClientHandler; -use crate::stats::{Stats, ReplayChecker}; +use crate::config::{LogLevel, ProxyConfig}; use crate::crypto::SecureRandom; -use crate::transport::{create_listener, ListenOptions, UpstreamManager}; -use crate::util::ip::detect_ip; +use crate::proxy::ClientHandler; +use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; +use crate::transport::middle_proxy::MePool; +use crate::transport::{ListenOptions, UpstreamManager, create_listener}; +use crate::util::ip::detect_ip; fn parse_cli() -> (String, bool, Option) { let mut config_path = "config.toml".to_string(); let mut silent = false; let mut log_level: Option = None; - + let args: Vec = std::env::args().skip(1).collect(); - + // Check for --init first (handled before tokio) if let Some(init_opts) = cli::parse_init_args(&args) { if let Err(e) = cli::run_init(init_opts) { @@ -43,14 +44,18 @@ fn parse_cli() -> (String, bool, Option) { } std::process::exit(0); } - + let mut i = 0; while i < args.len() { match args[i].as_str() { - "--silent" | "-s" => { silent = true; } + "--silent" | "-s" => { + silent = true; + } "--log-level" => { i += 1; - if i < args.len() { log_level = Some(args[i].clone()); } + if i < args.len() { + log_level = Some(args[i].clone()); + } } s if s.starts_with("--log-level=") => { log_level = Some(s.trim_start_matches("--log-level=").to_string()); @@ -64,26 +69,36 @@ fn parse_cli() -> (String, bool, Option) { eprintln!(" --help, -h Show this help"); eprintln!(); eprintln!("Setup (fire-and-forget):"); - eprintln!(" --init Generate config, install systemd service, start"); + eprintln!( + " --init Generate config, install systemd service, start" + ); eprintln!(" --port Listen port (default: 443)"); - eprintln!(" --domain TLS domain for masking (default: www.google.com)"); - eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); + eprintln!( + " --domain TLS domain for masking (default: www.google.com)" + ); + eprintln!( + " --secret 32-char hex secret (auto-generated if omitted)" + ); eprintln!(" --user Username (default: user)"); eprintln!(" --config-dir Config directory (default: /etc/telemt)"); eprintln!(" --no-start Don't start the service after install"); std::process::exit(0); } - s if !s.starts_with('-') => { config_path = s.to_string(); } - other => { eprintln!("Unknown option: {}", other); } + s if !s.starts_with('-') => { + config_path = s.to_string(); + } + other => { + eprintln!("Unknown option: {}", other); + } } i += 1; } - + (config_path, silent, log_level) } #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> std::result::Result<(), Box> { let (config_path, cli_silent, cli_log_level) = parse_cli(); let config = match ProxyConfig::load(&config_path) { @@ -100,7 +115,7 @@ async fn main() -> Result<(), Box> { } } }; - + if let Err(e) = config.validate() { eprintln!("[telemt] Invalid config: {}", e); std::process::exit(1); @@ -115,8 +130,6 @@ async fn main() -> Result<(), Box> { config.general.log_level.clone() }; - // Start with INFO so startup messages are always visible, - // then switch to user-configured level after startup let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); tracing_subscriber::registry() .with(filter_layer) @@ -125,90 +138,252 @@ async fn main() -> Result<(), Box> { info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Log level: {}", effective_log_level); - info!("Modes: classic={} secure={} tls={}", - config.general.modes.classic, - config.general.modes.secure, - config.general.modes.tls); + info!( + "Modes: classic={} secure={} tls={}", + config.general.modes.classic, config.general.modes.secure, config.general.modes.tls + ); info!("TLS domain: {}", config.censorship.tls_domain); if let Some(ref sock) = config.censorship.mask_unix_sock { info!("Mask: {} -> unix:{}", config.censorship.mask, sock); if !std::path::Path::new(sock).exists() { - warn!("Unix socket '{}' does not exist yet. Masking will fail until it appears.", sock); + warn!( + "Unix socket '{}' does not exist yet. Masking will fail until it appears.", + sock + ); } } else { - info!("Mask: {} -> {}:{}", + info!( + "Mask: {} -> {}:{}", config.censorship.mask, - config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), - config.censorship.mask_port); + config + .censorship + .mask_host + .as_deref() + .unwrap_or(&config.censorship.tls_domain), + config.censorship.mask_port + ); } if config.censorship.tls_domain == "www.google.com" { warn!("Using default tls_domain. Consider setting a custom domain."); } - + let prefer_ipv6 = config.general.prefer_ipv6; + let use_middle_proxy = config.general.use_middle_proxy; let config = Arc::new(config); let stats = Arc::new(Stats::new()); let rng = Arc::new(SecureRandom::new()); - + let replay_checker = Arc::new(ReplayChecker::new( config.access.replay_check_len, Duration::from_secs(config.access.replay_window_secs), )); - + let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); - - // Connection concurrency limit — prevents OOM under SYN flood / connection storm. - // 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead). - // 10000 connections ≈ 640MB peak memory. - let max_connections = Arc::new(Semaphore::new(10_000)); - - // Startup DC ping - info!("=== Telegram DC Connectivity ==="); - let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; - for upstream_result in &ping_results { - info!(" via {}", upstream_result.upstream_name); - for dc in &upstream_result.results { - match (&dc.rtt_ms, &dc.error) { - (Some(rtt), _) => { - info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt); - } - (None, Some(err)) => { - info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err); - } - _ => { - info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr); + + // Connection concurrency limit + let _max_connections = Arc::new(Semaphore::new(10_000)); + + // ===================================================================== + // Middle Proxy initialization (if enabled) + // ===================================================================== + let me_pool: Option> = if use_middle_proxy { + info!("=== Middle Proxy Mode ==="); + + // ad_tag (proxy_tag) for advertising + let proxy_tag = config.general.ad_tag.as_ref().map(|tag| { + hex::decode(tag).unwrap_or_else(|_| { + warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty"); + Vec::new() + }) + }); + + // ============================================================= + // CRITICAL: Download Telegram proxy-secret (NOT user secret!) + // + // C MTProxy uses TWO separate secrets: + // -S flag = 16-byte user secret for client obfuscation + // --aes-pwd = 32-512 byte binary file for ME RPC auth + // + // proxy-secret is from: https://core.telegram.org/getProxySecret + // ============================================================= + let proxy_secret_path = config.general.proxy_secret_path.as_deref(); + match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await { + Ok(proxy_secret) => { + info!( + secret_len = proxy_secret.len(), + key_sig = format_args!( + "0x{:08x}", + if proxy_secret.len() >= 4 { + u32::from_le_bytes([ + proxy_secret[0], + proxy_secret[1], + proxy_secret[2], + proxy_secret[3], + ]) + } else { + 0 + } + ), + "Proxy-secret loaded" + ); + + let pool = MePool::new( + proxy_tag, + proxy_secret, + config.general.middle_proxy_nat_ip, + config.general.middle_proxy_nat_probe, + config.general.middle_proxy_nat_stun.clone(), + ); + + match pool.init(2, &rng).await { + Ok(()) => { + info!("Middle-End pool initialized successfully"); + + // Phase 4: Start health monitor + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_health_monitor( + pool_clone, rng_clone, 2, + ) + .await; + }); + + Some(pool) + } + Err(e) => { + error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode."); + None + } } } + Err(e) => { + error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode."); + None + } + } + } else { + None + }; + + if me_pool.is_some() { + info!("Transport: Middle Proxy (supports all DCs including CDN)"); + } else { + info!("Transport: Direct TCP (standard DCs only)"); + } + + // Startup DC ping (only meaningful in direct mode) + if me_pool.is_none() { + info!("================= Telegram DC Connectivity ================="); + + let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; + + for upstream_result in &ping_results { + // Show which IP version is in use and which is fallback + if upstream_result.both_available { + if prefer_ipv6 { + info!(" IPv6 in use and IPv4 is fallback"); + } else { + info!(" IPv4 in use and IPv6 is fallback"); + } + } else { + let v6_works = upstream_result + .v6_results + .iter() + .any(|r| r.rtt_ms.is_some()); + let v4_works = upstream_result + .v4_results + .iter() + .any(|r| r.rtt_ms.is_some()); + if v6_works && !v4_works { + info!(" IPv6 only (IPv4 unavailable)"); + } else if v4_works && !v6_works { + info!(" IPv4 only (IPv6 unavailable)"); + } else if !v6_works && !v4_works { + info!(" No connectivity!"); + } + } + + info!(" via {}", upstream_result.upstream_name); + info!("============================================================"); + + // Print IPv6 results first + for dc in &upstream_result.v6_results { + let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); + match &dc.rtt_ms { + Some(rtt) => { + // Align: IPv6 addresses are longer, use fewer tabs + // [2001:b28:f23d:f001::a]:443 = ~28 chars + info!(" DC{} [IPv6] {}:\t\t{:.0} ms", dc.dc_idx, addr_str, rtt); + } + None => { + let err = dc.error.as_deref().unwrap_or("fail"); + info!(" DC{} [IPv6] {}:\t\tFAIL ({})", dc.dc_idx, addr_str, err); + } + } + } + + info!("============================================================"); + + // Print IPv4 results + for dc in &upstream_result.v4_results { + let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port()); + match &dc.rtt_ms { + Some(rtt) => { + // Align: IPv4 addresses are shorter, use more tabs + // 149.154.175.50:443 = ~18 chars + info!( + " DC{} [IPv4] {}:\t\t\t\t{:.0} ms", + dc.dc_idx, addr_str, rtt + ); + } + None => { + let err = dc.error.as_deref().unwrap_or("fail"); + info!( + " DC{} [IPv4] {}:\t\t\t\tFAIL ({})", + dc.dc_idx, addr_str, err + ); + } + } + } + + info!("============================================================"); } } - info!("================================"); - + // Background tasks let um_clone = upstream_manager.clone(); - tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; }); - + tokio::spawn(async move { + um_clone.run_health_checks(prefer_ipv6).await; + }); + let rc_clone = replay_checker.clone(); - tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; }); + tokio::spawn(async move { + rc_clone.run_periodic_cleanup().await; + }); let detected_ip = detect_ip().await; - debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); + debug!( + "Detected IPs: v4={:?} v6={:?}", + detected_ip.ipv4, detected_ip.ipv6 + ); let mut listeners = Vec::new(); - + for listener_conf in &config.server.listeners { let addr = SocketAddr::new(listener_conf.ip, config.server.port); let options = ListenOptions { ipv6_only: listener_conf.ip.is_ipv6(), ..Default::default() }; - + match create_listener(addr, &options) { Ok(socket) => { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); - + let public_ip = if let Some(ip) = listener_conf.announce_ip { ip } else if listener_conf.ip.is_unspecified() { @@ -227,17 +402,23 @@ async fn main() -> Result<(), Box> { if let Some(secret) = config.access.users.get(user_name) { info!("User: {}", user_name); if config.general.modes.classic { - info!(" Classic: tg://proxy?server={}&port={}&secret={}", - public_ip, config.server.port, secret); + info!( + " Classic: tg://proxy?server={}&port={}&secret={}", + public_ip, config.server.port, secret + ); } if config.general.modes.secure { - info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", - public_ip, config.server.port, secret); + info!( + " DD: tg://proxy?server={}&port={}&secret=dd{}", + public_ip, config.server.port, secret + ); } if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); - info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - public_ip, config.server.port, secret, domain_hex); + info!( + " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + public_ip, config.server.port, secret, domain_hex + ); } } else { warn!("User '{}' in show_link not found", user_name); @@ -245,15 +426,15 @@ async fn main() -> Result<(), Box> { } info!("------------------------"); } - + listeners.push(listener); - }, + } Err(e) => { error!("Failed to bind to {}: {}", addr, e); } } } - + if listeners.is_empty() { error!("No listeners. Exiting."); std::process::exit(1); @@ -265,7 +446,9 @@ async fn main() -> Result<(), Box> { } else { EnvFilter::new(effective_log_level.to_filter_str()) }; - filter_handle.reload(runtime_filter).expect("Failed to switch log filter"); + filter_handle + .reload(runtime_filter) + .expect("Failed to switch log filter"); for listener in listeners { let config = config.clone(); @@ -274,7 +457,8 @@ async fn main() -> Result<(), Box> { let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); - + let me_pool = me_pool.clone(); + tokio::spawn(async move { loop { match listener.accept().await { @@ -285,12 +469,23 @@ async fn main() -> Result<(), Box> { let replay_checker = replay_checker.clone(); let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); - + let me_pool = me_pool.clone(); + tokio::spawn(async move { if let Err(e) = ClientHandler::new( - stream, peer_addr, config, stats, - upstream_manager, replay_checker, buffer_pool, rng - ).run().await { + stream, + peer_addr, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + me_pool, + ) + .run() + .await + { debug!(peer = %peer_addr, error = %e, "Connection error"); } }); @@ -310,4 +505,4 @@ async fn main() -> Result<(), Box> { } Ok(()) -} \ No newline at end of file +} diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 7451c83..86cd2bd 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -202,6 +202,17 @@ pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[ // ============= RPC Constants (for Middle Proxy) ============= /// RPC Proxy Request + +/// RPC Flags (from Erlang mtp_rpc.erl) +pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2; +pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8; +pub const RPC_FLAG_MAGIC: u32 = 0x1000; +pub const RPC_FLAG_EXTMODE2: u32 = 0x20000; +pub const RPC_FLAG_PAD: u32 = 0x8000000; +pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000; +pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000; +pub const RPC_FLAG_QUICKACK: u32 = 0x80000000; + pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; /// RPC Proxy Answer pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44]; @@ -228,7 +239,56 @@ pub mod rpc_flags { pub const FLAG_QUICKACK: u32 = 0x80000000; } -#[cfg(test)] + + // ============= Middle-End Proxy Servers ============= + pub const ME_PROXY_PORT: u16 = 8888; + + pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock> = LazyLock::new(|| { + vec![ + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888), + (IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888), + (IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888), + ] + }); + + // ============= RPC Constants (u32 native endian) ============= + // From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c + + pub const RPC_NONCE_U32: u32 = 0x7acb87aa; + pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5; + pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda; + pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121 + + // mtproto-common.h + pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee; + pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d; + pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d; + pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2; + pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b; + pub const RPC_PING_U32: u32 = 0x5730a2df; + pub const RPC_PONG_U32: u32 = 0x8430eaa7; + + pub const RPC_CRYPTO_NONE_U32: u32 = 0; + pub const RPC_CRYPTO_AES_U32: u32 = 1; + + pub mod proxy_flags { + pub const FLAG_HAS_AD_TAG: u32 = 1; + pub const FLAG_NOT_ENCRYPTED: u32 = 0x2; + pub const FLAG_HAS_AD_TAG2: u32 = 0x8; + pub const FLAG_MAGIC: u32 = 0x1000; + pub const FLAG_EXTMODE2: u32 = 0x20000; + pub const FLAG_PAD: u32 = 0x8000000; + pub const FLAG_INTERMEDIATE: u32 = 0x20000000; + pub const FLAG_ABRIDGED: u32 = 0x40000000; + pub const FLAG_QUICKACK: u32 = 0x80000000; + } + + pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; + pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; + + #[cfg(test)] mod tests { use super::*; diff --git a/src/proxy/client.rs b/src/proxy/client.rs index adcb25b..726d238 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -3,26 +3,25 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::net::TcpStream; -use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::time::timeout; -use tracing::{debug, info, warn, error, trace}; +use tracing::{debug, warn}; use crate::config::ProxyConfig; -use crate::error::{ProxyError, Result, HandshakeResult}; +use crate::crypto::SecureRandom; +use crate::error::{HandshakeResult, ProxyError, Result}; use crate::protocol::constants::*; use crate::protocol::tls; -use crate::stats::{Stats, ReplayChecker}; -use crate::transport::{configure_client_socket, UpstreamManager}; -use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; -use crate::crypto::{AesCtr, SecureRandom}; +use crate::stats::{ReplayChecker, Stats}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::middle_proxy::MePool; +use crate::transport::{UpstreamManager, configure_client_socket}; -use crate::proxy::handshake::{ - handle_tls_handshake, handle_mtproto_handshake, - HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, -}; -use crate::proxy::relay::relay_bidirectional; +use crate::proxy::direct_relay::handle_via_direct; +use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::masking::handle_bad_client; +use crate::proxy::middle_relay::handle_via_middle_proxy; pub struct ClientHandler; @@ -35,6 +34,7 @@ pub struct RunningClientHandler { upstream_manager: Arc, buffer_pool: Arc, rng: Arc, + me_pool: Option>, } impl ClientHandler { @@ -47,10 +47,18 @@ impl ClientHandler { replay_checker: Arc, buffer_pool: Arc, rng: Arc, + me_pool: Option>, ) -> RunningClientHandler { RunningClientHandler { - stream, peer, config, stats, replay_checker, - upstream_manager, buffer_pool, rng, + stream, + peer, + config, + stats, + replay_checker, + upstream_manager, + buffer_pool, + rng, + me_pool, } } } @@ -58,10 +66,10 @@ impl ClientHandler { impl RunningClientHandler { pub async fn run(mut self) -> Result<()> { self.stats.increment_connects_all(); - + let peer = self.peer; debug!(peer = %peer, "New connection"); - + if let Err(e) = configure_client_socket( &self.stream, self.config.timeouts.client_keepalive, @@ -69,12 +77,12 @@ impl RunningClientHandler { ) { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } - + let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let stats = self.stats.clone(); - + let result = timeout(handshake_timeout, self.do_handshake()).await; - + match result { Ok(Ok(())) => { debug!(peer = %peer, "Connection handled successfully"); @@ -91,30 +99,30 @@ impl RunningClientHandler { } } } - + async fn do_handshake(mut self) -> Result<()> { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; - + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; - + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); - + if is_tls { self.handle_tls_client(first_bytes).await } else { self.handle_direct_client(first_bytes).await } } - + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; - + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - + debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - + if tls_len < 512 { debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); self.stats.increment_connects_bad(); @@ -122,22 +130,30 @@ impl RunningClientHandler { handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } - + let mut handshake = vec![0u8; 5 + tls_len]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; - + let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - + + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); - + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, &self.rng, - ).await { + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + &self.rng, + ) + .await + { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); @@ -146,35 +162,54 @@ impl RunningClientHandler { } HandshakeResult::Error(e) => return Err(e), }; - + debug!(peer = %peer, "Reading MTProto handshake through TLS"); let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; - let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..] + .try_into() .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; - + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, tls_reader, tls_writer, peer, - &config, &replay_checker, true, - ).await { + &mtproto_handshake, + tls_reader, + tls_writer, + peer, + &config, + &replay_checker, + true, + ) + .await + { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { + HandshakeResult::BadClient { + reader: _, + writer: _, + } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; - + Self::handle_authenticated_static( - crypto_reader, crypto_writer, success, - self.upstream_manager, self.stats, self.config, - buffer_pool, self.rng, - ).await + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + ) + .await } - + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { let peer = self.peer; - + if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); @@ -182,22 +217,30 @@ impl RunningClientHandler { handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } - + let mut handshake = [0u8; HANDSHAKE_LEN]; handshake[..5].copy_from_slice(&first_bytes); self.stream.read_exact(&mut handshake[5..]).await?; - + let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - + + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); - + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, false, - ).await { + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + false, + ) + .await + { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); @@ -206,14 +249,26 @@ impl RunningClientHandler { } HandshakeResult::Error(e) => return Err(e), }; - + Self::handle_authenticated_static( - crypto_reader, crypto_writer, success, - self.upstream_manager, self.stats, self.config, - buffer_pool, self.rng, - ).await + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + ) + .await } - + + /// Main dispatch after successful handshake. + /// Two modes: + /// - Direct: TCP relay to TG DC (existing behavior) + /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) async fn handle_authenticated_static( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -223,180 +278,77 @@ impl RunningClientHandler { config: Arc, buffer_pool: Arc, rng: Arc, + me_pool: Option>, + local_addr: SocketAddr, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { let user = &success.user; - + if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { warn!(user = %user, error = %e, "User limit exceeded"); return Err(e); } - - let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; - - info!( - user = %user, - peer = %success.peer, - dc = success.dc_idx, - dc_addr = %dc_addr, - proto = ?success.proto_tag, - "Connecting to Telegram" - ); - - // Pass dc_idx for latency-based upstream selection - let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?; - - debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); - - let (tg_reader, tg_writer) = Self::do_tg_handshake_static( - tg_stream, &success, &config, rng.as_ref(), - ).await?; - - debug!(peer = %success.peer, "TG handshake complete, starting relay"); - - stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); - - let relay_result = relay_bidirectional( - client_reader, client_writer, - tg_reader, tg_writer, - user, Arc::clone(&stats), buffer_pool, - ).await; - - stats.decrement_user_curr_connects(user); - - match &relay_result { - Ok(()) => debug!(user = %user, "Relay completed"), - Err(e) => debug!(user = %user, error = %e, "Relay ended with error"), + + // Decide: middle proxy or direct + if config.general.use_middle_proxy { + if let Some(ref pool) = me_pool { + return handle_via_middle_proxy( + client_reader, + client_writer, + success, + pool.clone(), + stats, + config, + buffer_pool, + local_addr, + ) + .await; + } + warn!("use_middle_proxy=true but MePool not initialized, falling back to direct"); } - - relay_result + + // Direct mode (original behavior) + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + ) + .await } - + fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { if let Some(expiration) = config.access.user_expirations.get(user) { if chrono::Utc::now() > *expiration { - return Err(ProxyError::UserExpired { user: user.to_string() }); + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); } } - + if let Some(limit) = config.access.user_max_tcp_conns.get(user) { if stats.get_user_curr_connects(user) >= *limit as u64 { - return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); } } - + if let Some(quota) = config.access.user_data_quota.get(user) { if stats.get_user_total_octets(user) >= *quota { - return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); } } - + Ok(()) } - - /// Resolve DC index to a target address. - /// - /// Matches the C implementation's behavior exactly: - /// - /// 1. Look up DC in known clusters (standard DCs ±1..±5) - /// 2. If not found and `force=1` → fall back to `default_cluster` - /// - /// In the C code: - /// - `proxy-multi.conf` is downloaded from Telegram, contains only DC ±1..±5 - /// - `default 2;` directive sets the default cluster - /// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns default_cluster - /// for any unknown DC (like CDN DC 203) - /// - /// So DC 203, DC 101, DC -300, etc. all route to the default DC (2). - /// There is NO modular arithmetic in the C implementation. - fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { - let datacenters = if config.general.prefer_ipv6 { - &*TG_DATACENTERS_V6 - } else { - &*TG_DATACENTERS_V4 - }; - - let num_dcs = datacenters.len(); // 5 - - // === Step 1: Check dc_overrides (like C's `proxy_for :`) === - let dc_key = dc_idx.to_string(); - if let Some(addr_str) = config.dc_overrides.get(&dc_key) { - match addr_str.parse::() { - Ok(addr) => { - debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config"); - return Ok(addr); - } - Err(_) => { - warn!(dc_idx = dc_idx, addr_str = %addr_str, - "Invalid DC override address in config, ignoring"); - } - } - } - - // === Step 2: Standard DCs ±1..±5 — direct lookup === - let abs_dc = dc_idx.unsigned_abs() as usize; - if abs_dc >= 1 && abs_dc <= num_dcs { - return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT)); - } - - // === Step 3: Unknown DC — fall back to default_cluster === - // Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)` - // which returns `MC->default_cluster` when the DC is not found. - // Telegram's proxy-multi.conf uses `default 2;` - let default_dc = config.default_dc.unwrap_or(2) as usize; - let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { - default_dc - 1 - } else { - 1 // DC 2 (index 1) — matches Telegram's `default 2;` - }; - - info!( - original_dc = dc_idx, - fallback_dc = (fallback_idx + 1) as u16, - fallback_addr = %datacenters[fallback_idx], - "Special DC ---> default_cluster" - ); - - Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT)) - } - - async fn do_tg_handshake_static( - mut stream: TcpStream, - success: &HandshakeSuccess, - config: &ProxyConfig, - rng: &SecureRandom, - ) -> Result<(CryptoReader, CryptoWriter)> { - let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( - success.proto_tag, - &success.dec_key, - success.dec_iv, - rng, - config.general.fast_mode, - ); - - let encrypted_nonce = encrypt_tg_nonce(&nonce); - - debug!( - peer = %success.peer, - nonce_head = %hex::encode(&nonce[..16]), - "Sending nonce to Telegram" - ); - - stream.write_all(&encrypted_nonce).await?; - stream.flush().await?; - - let (read_half, write_half) = stream.into_split(); - - let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); - let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); - - Ok(( - CryptoReader::new(read_half, decryptor), - CryptoWriter::new(write_half, encryptor), - )) - } -} \ No newline at end of file +} diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs new file mode 100644 index 0000000..46c004c --- /dev/null +++ b/src/proxy/direct_relay.rs @@ -0,0 +1,163 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use tracing::{debug, info, warn}; + +use crate::config::ProxyConfig; +use crate::crypto::SecureRandom; +use crate::error::Result; +use crate::protocol::constants::*; +use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; +use crate::proxy::relay::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; + +pub(crate) async fn handle_via_direct( + client_reader: CryptoReader, + client_writer: CryptoWriter, + success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + rng: Arc, +) -> Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let user = &success.user; + let dc_addr = get_dc_addr_static(success.dc_idx, &config)?; + + info!( + user = %user, + peer = %success.peer, + dc = success.dc_idx, + dc_addr = %dc_addr, + proto = ?success.proto_tag, + mode = "direct", + "Connecting to Telegram DC" + ); + + let tg_stream = upstream_manager + .connect(dc_addr, Some(success.dc_idx)) + .await?; + + debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); + + let (tg_reader, tg_writer) = + do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?; + + debug!(peer = %success.peer, "TG handshake complete, starting relay"); + + stats.increment_user_connects(user); + stats.increment_user_curr_connects(user); + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + tg_reader, + tg_writer, + user, + Arc::clone(&stats), + buffer_pool, + ) + .await; + + stats.decrement_user_curr_connects(user); + + match &relay_result { + Ok(()) => debug!(user = %user, "Direct relay completed"), + Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), + } + + relay_result +} + +fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { + let datacenters = if config.general.prefer_ipv6 { + &*TG_DATACENTERS_V6 + } else { + &*TG_DATACENTERS_V4 + }; + + let num_dcs = datacenters.len(); + + let dc_key = dc_idx.to_string(); + if let Some(addr_str) = config.dc_overrides.get(&dc_key) { + match addr_str.parse::() { + Ok(addr) => { + debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config"); + return Ok(addr); + } + Err(_) => { + warn!(dc_idx = dc_idx, addr_str = %addr_str, + "Invalid DC override address in config, ignoring"); + } + } + } + + let abs_dc = dc_idx.unsigned_abs() as usize; + if abs_dc >= 1 && abs_dc <= num_dcs { + return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT)); + } + + let default_dc = config.default_dc.unwrap_or(2) as usize; + let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { + default_dc - 1 + } else { + 1 + }; + + info!( + original_dc = dc_idx, + fallback_dc = (fallback_idx + 1) as u16, + fallback_addr = %datacenters[fallback_idx], + "Special DC ---> default_cluster" + ); + + Ok(SocketAddr::new( + datacenters[fallback_idx], + TG_DATACENTER_PORT, + )) +} + +async fn do_tg_handshake_static( + mut stream: TcpStream, + success: &HandshakeSuccess, + config: &ProxyConfig, + rng: &SecureRandom, +) -> Result<( + CryptoReader, + CryptoWriter, +)> { + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + success.proto_tag, + success.dc_idx, + &success.dec_key, + success.dec_iv, + rng, + config.general.fast_mode, + ); + + let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce); + + debug!( + peer = %success.peer, + nonce_head = %hex::encode(&nonce[..16]), + "Sending nonce to Telegram" + ); + + stream.write_all(&encrypted_nonce).await?; + stream.flush().await?; + + let (read_half, write_half) = stream.into_split(); + + Ok(( + CryptoReader::new(read_half, tg_decryptor), + CryptoWriter::new(write_half, tg_encryptor), + )) +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 81814e2..ab8e70c 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -61,26 +61,26 @@ where W: AsyncWrite + Unpin, { debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); - + if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); return HandshakeResult::BadClient { reader, writer }; } - + let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; - + if replay_checker.check_tls_digest(digest_half) { warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } - + let secrets: Vec<(String, Vec)> = config.access.users.iter() .filter_map(|(name, hex)| { hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) }) .collect(); - + let validation = match tls::validate_tls_handshake( handshake, &secrets, @@ -96,12 +96,12 @@ where return HandshakeResult::BadClient { reader, writer }; } }; - + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => return HandshakeResult::BadClient { reader, writer }, }; - + let response = tls::build_server_hello( secret, &validation.digest, @@ -109,27 +109,27 @@ where config.censorship.fake_cert_len, rng, ); - + debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); - + if let Err(e) = writer.write_all(&response).await { warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello"); return HandshakeResult::Error(ProxyError::Io(e)); } - + if let Err(e) = writer.flush().await { warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello"); return HandshakeResult::Error(ProxyError::Io(e)); } - + replay_checker.add_tls_digest(digest_half); - + info!( peer = %peer, user = %validation.user, "TLS handshake successful" ); - + HandshakeResult::Success(( FakeTlsReader::new(reader), FakeTlsWriter::new(writer), @@ -152,75 +152,74 @@ where W: AsyncWrite + Unpin + Send, { trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); - + let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; - + if replay_checker.check_handshake(dec_prekey_iv) { warn!(peer = %peer, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; } - + let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); - + for (user, secret_hex) in &config.access.users { let secret = match hex::decode(secret_hex) { Ok(s) => s, Err(_) => continue, }; - + let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; - + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(&secret); let dec_key = sha256(&dec_key_input); - + let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); - + let mut decryptor = AesCtr::new(&dec_key, dec_iv); let decrypted = decryptor.decrypt(handshake); - + let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] .try_into() .unwrap(); - + let proto_tag = match ProtoTag::from_bytes(tag_bytes) { Some(tag) => tag, None => continue, }; - + let mode_ok = match proto_tag { ProtoTag::Secure => { if is_tls { config.general.modes.tls } else { config.general.modes.secure } } ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, }; - + if !mode_ok { debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); continue; } - + let dc_idx = i16::from_le_bytes( decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() ); - + let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; - + let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(&secret); let enc_key = sha256(&enc_key_input); - + let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); - + replay_checker.add_handshake(dec_prekey_iv); - - let decryptor = AesCtr::new(&dec_key, dec_iv); + let encryptor = AesCtr::new(&enc_key, enc_iv); - + let success = HandshakeSuccess { user: user.clone(), dc_idx, @@ -232,7 +231,7 @@ where peer, is_tls, }; - + info!( peer = %peer, user = %user, @@ -241,14 +240,14 @@ where tls = is_tls, "MTProto handshake successful" ); - + return HandshakeResult::Success(( CryptoReader::new(reader, decryptor), CryptoWriter::new(writer, encryptor), success, )); } - + debug!(peer = %peer, "MTProto handshake: no matching user found"); HandshakeResult::BadClient { reader, writer } } @@ -256,6 +255,7 @@ where /// Generate nonce for Telegram connection pub fn generate_tg_nonce( proto_tag: ProtoTag, + dc_idx: i16, client_dec_key: &[u8; 32], client_dec_iv: u128, rng: &SecureRandom, @@ -264,86 +264,101 @@ pub fn generate_tg_nonce( loop { let bytes = rng.bytes(HANDSHAKE_LEN); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); - + if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } - + let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; } - + let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } - + nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); - + // CRITICAL: write dc_idx so upstream DC knows where to route + nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + if fast_mode { nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] .copy_from_slice(&client_dec_iv.to_be_bytes()); } - + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); - + let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); - + let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); - + return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); } } -/// Encrypt nonce for sending to Telegram -pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { +/// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state +pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec, AesCtr, AesCtr) { let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; - let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); - let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); - - let mut encryptor = AesCtr::new(&key, iv); - let encrypted_full = encryptor.encrypt(nonce); - + let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + + let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); + let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); + + let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); + let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); + + let mut encryptor = AesCtr::new(&enc_key, enc_iv); + let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4 + let mut result = nonce[..PROTO_TAG_POS].to_vec(); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); - - result + + let decryptor = AesCtr::new(&dec_key, dec_iv); + + (result, encryptor, decryptor) +} + +/// Encrypt nonce for sending to Telegram (legacy function for compatibility) +pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec { + let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce); + encrypted } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_generate_tg_nonce() { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; - + let rng = SecureRandom::new(); let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); - + generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); + assert_eq!(nonce.len(), HANDSHAKE_LEN); - + let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); } - + #[test] fn test_encrypt_tg_nonce() { let client_dec_key = [0x42u8; 32]; let client_dec_iv = 12345u128; - + let rng = SecureRandom::new(); let (nonce, _, _, _, _) = - generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); - + generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false); + let encrypted = encrypt_tg_nonce(&nonce); - + assert_eq!(encrypted.len(), HANDSHAKE_LEN); assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); } - + #[test] fn test_handshake_success_zeroize_on_drop() { let success = HandshakeSuccess { @@ -357,10 +372,10 @@ mod tests { peer: "127.0.0.1:1234".parse().unwrap(), is_tls: true, }; - + assert_eq!(success.dec_key, [0xAA; 32]); assert_eq!(success.enc_key, [0xCC; 32]); - + drop(success); // Drop impl zeroizes key material without panic } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs new file mode 100644 index 0000000..0882d0e --- /dev/null +++ b/src/proxy/middle_relay.rs @@ -0,0 +1,254 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tracing::{debug, info, trace}; + +use crate::config::ProxyConfig; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; +use crate::proxy::handshake::HandshakeSuccess; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; + +pub(crate) async fn handle_via_middle_proxy( + mut crypto_reader: CryptoReader, + mut crypto_writer: CryptoWriter, + success: HandshakeSuccess, + me_pool: Arc, + stats: Arc, + _config: Arc, + _buffer_pool: Arc, + local_addr: SocketAddr, +) -> Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let user = success.user.clone(); + let peer = success.peer; + let proto_tag = success.proto_tag; + + info!( + user = %user, + peer = %peer, + dc = success.dc_idx, + proto = ?proto_tag, + mode = "middle_proxy", + "Routing via Middle-End" + ); + + let (conn_id, mut me_rx) = me_pool.registry().register().await; + + stats.increment_user_connects(&user); + stats.increment_user_curr_connects(&user); + + let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag()); + debug!( + user = %user, + conn_id, + proto_flags = format_args!("0x{:08x}", proto_flags), + "ME relay started" + ); + + let translated_local_addr = me_pool.translate_our_addr(local_addr); + + let result: Result<()> = loop { + tokio::select! { + client_frame = read_client_payload(&mut crypto_reader, proto_tag) => { + match client_frame { + Ok(Some(payload)) => { + trace!(conn_id, bytes = payload.len(), "C->ME frame"); + stats.add_user_octets_from(&user, payload.len() as u64); + me_pool.send_proxy_req( + conn_id, + success.dc_idx, + peer, + translated_local_addr, + &payload, + proto_flags, + ).await?; + } + Ok(None) => { + debug!(conn_id, "Client EOF"); + let _ = me_pool.send_close(conn_id).await; + break Ok(()); + } + Err(e) => break Err(e), + } + } + me_msg = me_rx.recv() => { + match me_msg { + Some(MeResponse::Data { flags, data }) => { + trace!(conn_id, bytes = data.len(), flags, "ME->C data"); + stats.add_user_octets_to(&user, data.len() as u64); + write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?; + } + Some(MeResponse::Ack(confirm)) => { + trace!(conn_id, confirm, "ME->C quickack"); + write_client_ack(&mut crypto_writer, proto_tag, confirm).await?; + } + Some(MeResponse::Close) => { + debug!(conn_id, "ME sent close"); + break Ok(()); + } + None => { + debug!(conn_id, "ME channel closed"); + break Err(ProxyError::Proxy("ME connection lost".into())); + } + } + } + } + }; + + debug!(user = %user, conn_id, "ME relay cleanup"); + me_pool.registry().unregister(conn_id).await; + stats.decrement_user_curr_connects(&user); + result +} + +async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, +) -> Result>> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let len = match proto_tag { + ProtoTag::Abridged => { + let mut first = [0u8; 1]; + match client_reader.read_exact(&mut first).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + + let len_words = if (first[0] & 0x7f) == 0x7f { + let mut ext = [0u8; 3]; + client_reader + .read_exact(&mut ext) + .await + .map_err(ProxyError::Io)?; + u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize + } else { + (first[0] & 0x7f) as usize + }; + + len_words + .checked_mul(4) + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))? + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len_buf = [0u8; 4]; + match client_reader.read_exact(&mut len_buf).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize + } + }; + + if len > 16 * 1024 * 1024 { + return Err(ProxyError::Proxy(format!("Frame too large: {len}"))); + } + + let mut payload = vec![0u8; len]; + client_reader + .read_exact(&mut payload) + .await + .map_err(ProxyError::Io)?; + Ok(Some(payload)) +} + +async fn write_client_payload( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + flags: u32, + data: &[u8], +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let quickack = (flags & RPC_FLAG_QUICKACK) != 0; + + match proto_tag { + ProtoTag::Abridged => { + if data.len() % 4 != 0 { + return Err(ProxyError::Proxy(format!( + "Abridged payload must be 4-byte aligned, got {}", + data.len() + ))); + } + + let len_words = data.len() / 4; + if len_words < 0x7f { + let mut first = len_words as u8; + if quickack { + first |= 0x80; + } + client_writer + .write_all(&[first]) + .await + .map_err(ProxyError::Io)?; + } else if len_words < (1 << 24) { + let mut first = 0x7fu8; + if quickack { + first |= 0x80; + } + let lw = (len_words as u32).to_le_bytes(); + client_writer + .write_all(&[first, lw[0], lw[1], lw[2]]) + .await + .map_err(ProxyError::Io)?; + } else { + return Err(ProxyError::Proxy(format!( + "Abridged frame too large: {}", + data.len() + ))); + } + + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len = data.len() as u32; + if quickack { + len |= 0x8000_0000; + } + client_writer + .write_all(&len.to_le_bytes()) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; + } + } + + client_writer.flush().await.map_err(ProxyError::Io) +} + +async fn write_client_ack( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + confirm: u32, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let bytes = if proto_tag == ProtoTag::Abridged { + confirm.to_be_bytes() + } else { + confirm.to_le_bytes() + }; + client_writer + .write_all(&bytes) + .await + .map_err(ProxyError::Io)?; + client_writer.flush().await.map_err(ProxyError::Io) +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 92dd373..d6243aa 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,11 +1,13 @@ //! Proxy Defs -pub mod handshake; pub mod client; -pub mod relay; +pub mod direct_relay; +pub mod handshake; pub mod masking; +pub mod middle_relay; +pub mod relay; -pub use handshake::*; pub use client::ClientHandler; +pub use handshake::*; +pub use masking::*; pub use relay::*; -pub use masking::*; \ No newline at end of file diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs new file mode 100644 index 0000000..51daee9 --- /dev/null +++ b/src/transport/middle_proxy/codec.rs @@ -0,0 +1,179 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use crate::crypto::{AesCbc, crc32}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { + let total_len = (4 + 4 + payload.len() + 4) as u32; + let mut frame = Vec::with_capacity(total_len as usize); + frame.extend_from_slice(&total_len.to_le_bytes()); + frame.extend_from_slice(&seq_no.to_le_bytes()); + frame.extend_from_slice(payload); + let c = crc32(&frame); + frame.extend_from_slice(&c.to_le_bytes()); + frame +} + +pub(crate) async fn read_rpc_frame_plaintext( + rd: &mut (impl AsyncReadExt + Unpin), +) -> Result<(i32, Vec)> { + let mut len_buf = [0u8; 4]; + rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?; + let total_len = u32::from_le_bytes(len_buf) as usize; + + if !(12..=(1 << 24)).contains(&total_len) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad RPC frame length: {total_len}" + ))); + } + + let mut rest = vec![0u8; total_len - 4]; + rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?; + + let mut full = Vec::with_capacity(total_len); + full.extend_from_slice(&len_buf); + full.extend_from_slice(&rest); + + let crc_offset = total_len - 4; + let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); + let actual_crc = crc32(&full[..crc_offset]); + if expected_crc != actual_crc { + return Err(ProxyError::InvalidHandshake(format!( + "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" + ))); + } + + let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap()); + let payload = full[8..crc_offset].to_vec(); + Ok((seq_no, payload)) +} + +pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] { + let mut p = [0u8; 32]; + p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes()); + p[4..8].copy_from_slice(&key_selector.to_le_bytes()); + p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes()); + p[12..16].copy_from_slice(&crypto_ts.to_le_bytes()); + p[16..32].copy_from_slice(nonce); + p +} + +pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, u32, [u8; 16])> { + if d.len() < 32 { + return Err(ProxyError::InvalidHandshake(format!( + "Nonce payload too short: {} bytes", + d.len() + ))); + } + + let t = u32::from_le_bytes(d[0..4].try_into().unwrap()); + if t != RPC_NONCE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}" + ))); + } + + let key_select = u32::from_le_bytes(d[4..8].try_into().unwrap()); + let schema = u32::from_le_bytes(d[8..12].try_into().unwrap()); + let ts = u32::from_le_bytes(d[12..16].try_into().unwrap()); + let mut nonce = [0u8; 16]; + nonce.copy_from_slice(&d[16..32]); + Ok((key_select, schema, ts, nonce)) +} + +pub(crate) fn build_handshake_payload( + our_ip: [u8; 4], + our_port: u16, + peer_ip: [u8; 4], + peer_port: u16, +) -> [u8; 32] { + let mut p = [0u8; 32]; + p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); + + // Keep C memory layout compatibility for PID IPv4 bytes. + p[8..12].copy_from_slice(&our_ip); + p[12..14].copy_from_slice(&our_port.to_le_bytes()); + let pid = (std::process::id() & 0xffff) as u16; + p[14..16].copy_from_slice(&pid.to_le_bytes()); + let utime = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + p[16..20].copy_from_slice(&utime.to_le_bytes()); + + p[20..24].copy_from_slice(&peer_ip); + p[24..26].copy_from_slice(&peer_port.to_le_bytes()); + p +} + +pub(crate) fn cbc_encrypt_padded( + key: &[u8; 32], + iv: &[u8; 16], + plaintext: &[u8], +) -> Result<(Vec, [u8; 16])> { + let pad = (16 - (plaintext.len() % 16)) % 16; + let mut buf = plaintext.to_vec(); + let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; + for i in 0..pad { + buf.push(pad_pattern[i % 4]); + } + + let cipher = AesCbc::new(*key, *iv); + cipher + .encrypt_in_place(&mut buf) + .map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?; + + let mut new_iv = [0u8; 16]; + if buf.len() >= 16 { + new_iv.copy_from_slice(&buf[buf.len() - 16..]); + } + Ok((buf, new_iv)) +} + +pub(crate) fn cbc_decrypt_inplace( + key: &[u8; 32], + iv: &[u8; 16], + data: &mut [u8], +) -> Result<[u8; 16]> { + let mut new_iv = [0u8; 16]; + if data.len() >= 16 { + new_iv.copy_from_slice(&data[data.len() - 16..]); + } + + AesCbc::new(*key, *iv) + .decrypt_in_place(data) + .map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?; + Ok(new_iv) +} + +pub(crate) struct RpcWriter { + pub(crate) writer: tokio::io::WriteHalf, + pub(crate) key: [u8; 32], + pub(crate) iv: [u8; 16], + pub(crate) seq_no: i32, +} + +impl RpcWriter { + pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { + let frame = build_rpc_frame(self.seq_no, payload); + self.seq_no += 1; + + let pad = (16 - (frame.len() % 16)) % 16; + let mut buf = frame; + let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; + for i in 0..pad { + buf.push(pad_pattern[i % 4]); + } + + let cipher = AesCbc::new(self.key, self.iv); + cipher + .encrypt_in_place(&mut buf) + .map_err(|e| ProxyError::Crypto(format!("{e}")))?; + + if buf.len() >= 16 { + self.iv.copy_from_slice(&buf[buf.len() - 16..]); + } + self.writer.write_all(&buf).await.map_err(ProxyError::Io) + } +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs new file mode 100644 index 0000000..d720c86 --- /dev/null +++ b/src/transport/middle_proxy/health.rs @@ -0,0 +1,38 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; +use crate::protocol::constants::TG_MIDDLE_PROXIES_FLAT_V4; + +use super::MePool; + +pub async fn me_health_monitor(pool: Arc, rng: Arc, min_connections: usize) { + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + let current = pool.connection_count(); + if current < min_connections { + warn!( + current, + min = min_connections, + "ME pool below minimum, reconnecting..." + ); + let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone(); + for &(ip, port) in addrs.iter() { + let needed = min_connections.saturating_sub(pool.connection_count()); + if needed == 0 { + break; + } + for _ in 0..needed { + let addr = SocketAddr::new(ip, port); + match pool.connect_one(addr, &rng).await { + Ok(()) => info!(%addr, "ME reconnected"), + Err(e) => debug!(%addr, error = %e, "ME reconnect failed"), + } + } + } + } + } +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs new file mode 100644 index 0000000..4906c4b --- /dev/null +++ b/src/transport/middle_proxy/mod.rs @@ -0,0 +1,26 @@ +//! Middle Proxy RPC transport. + +mod codec; +mod health; +mod pool; +mod pool_nat; +mod reader; +mod registry; +mod send; +mod secret; +mod wire; + +use bytes::Bytes; + +pub use health::me_health_monitor; +pub use pool::MePool; +pub use registry::ConnRegistry; +pub use secret::fetch_proxy_secret; +pub use wire::proto_flags_for_tag; + +#[derive(Debug)] +pub enum MeResponse { + Data { flags: u32, data: Bytes }, + Ack(u32), + Close, +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs new file mode 100644 index 0000000..650a029 --- /dev/null +++ b/src/transport/middle_proxy/pool.rs @@ -0,0 +1,499 @@ +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::AtomicU64; +use std::time::Duration; + +use bytes::BytesMut; +use rand::Rng; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::{Instant, timeout}; +use tracing::{debug, info, warn}; + +use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +use super::ConnRegistry; +use super::codec::{ + RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, + cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, +}; +use super::reader::reader_loop; +use super::wire::{IpMaterial, extract_ip_material}; + +const ME_ACTIVE_PING_SECS: u64 = 25; +const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; + +pub struct MePool { + pub(super) registry: Arc, + pub(super) writers: Arc>)>>> , + pub(super) rr: AtomicU64, + pub(super) proxy_tag: Option>, + proxy_secret: Vec, + pub(super) nat_ip_cfg: Option, + pub(super) nat_ip_detected: OnceLock, + pub(super) nat_probe: bool, + pub(super) nat_stun: Option, + pool_size: usize, +} + +impl MePool { + pub fn new( + proxy_tag: Option>, + proxy_secret: Vec, + nat_ip: Option, + nat_probe: bool, + nat_stun: Option, + ) -> Arc { + Arc::new(Self { + registry: Arc::new(ConnRegistry::new()), + writers: Arc::new(RwLock::new(Vec::new())), + rr: AtomicU64::new(0), + proxy_tag, + proxy_secret, + nat_ip_cfg: nat_ip, + nat_ip_detected: OnceLock::new(), + nat_probe, + nat_stun, + pool_size: 2, + }) + } + + pub fn has_proxy_tag(&self) -> bool { + self.proxy_tag.is_some() + } + + pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { + let ip = self.translate_ip_for_nat(addr.ip()); + SocketAddr::new(ip, addr.port()) + } + + pub fn registry(&self) -> &Arc { + &self.registry + } + + fn writers_arc(&self) -> Arc>)>>> + { + self.writers.clone() + } + + fn key_selector(&self) -> u32 { + if self.proxy_secret.len() >= 4 { + u32::from_le_bytes([ + self.proxy_secret[0], + self.proxy_secret[1], + self.proxy_secret[2], + self.proxy_secret[3], + ]) + } else { + 0 + } + } + + pub async fn init(self: &Arc, pool_size: usize, rng: &SecureRandom) -> Result<()> { + let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4; + let ks = self.key_selector(); + info!( + me_servers = addrs.len(), + pool_size, + key_selector = format_args!("0x{ks:08x}"), + secret_len = self.proxy_secret.len(), + "Initializing ME pool" + ); + + for &(ip, port) in addrs.iter() { + for i in 0..pool_size { + let addr = SocketAddr::new(ip, port); + match self.connect_one(addr, rng).await { + Ok(()) => info!(%addr, idx = i, "ME connected"), + Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"), + } + } + if self.writers.read().await.len() >= pool_size { + break; + } + } + + if self.writers.read().await.is_empty() { + return Err(ProxyError::Proxy("No ME connections".into())); + } + Ok(()) + } + + pub(crate) async fn connect_one( + self: &Arc, + addr: SocketAddr, + rng: &SecureRandom, + ) -> Result<()> { + let secret = &self.proxy_secret; + if secret.len() < 32 { + return Err(ProxyError::Proxy( + "proxy-secret too short for ME auth".into(), + )); + } + + let stream = timeout( + Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), + TcpStream::connect(addr), + ) + .await + .map_err(|_| ProxyError::ConnectionTimeout { + addr: addr.to_string(), + })? + .map_err(ProxyError::Io)?; + stream.set_nodelay(true).ok(); + + let local_addr = stream.local_addr().map_err(ProxyError::Io)?; + let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; + let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; + let reflected = if self.nat_probe { + self.maybe_reflect_public_addr().await + } else { + None + }; + let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); + let peer_addr_nat = + SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let (mut rd, mut wr) = tokio::io::split(stream); + + let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); + let crypto_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + + let ks = self.key_selector(); + let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); + let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); + info!( + key_selector = format_args!("0x{ks:08x}"), + crypto_ts, + frame_len = nonce_frame.len(), + nonce_frame_hex = %dump, + "Sending ME nonce frame" + ); + wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let (srv_seq, srv_nonce_payload) = timeout( + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), + read_rpc_frame_plaintext(&mut rd), + ) + .await + .map_err(|_| ProxyError::TgHandshakeTimeout)??; + + if srv_seq != -2 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected seq=-2, got {srv_seq}" + ))); + } + + let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; + if schema != RPC_CRYPTO_AES_U32 { + warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); + return Err(ProxyError::InvalidHandshake(format!( + "Unsupported crypto schema: 0x{schema:x}" + ))); + } + + if srv_key_select != ks { + return Err(ProxyError::InvalidHandshake(format!( + "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" + ))); + } + + let skew = crypto_ts.abs_diff(srv_ts); + if skew > 30 { + return Err(ProxyError::InvalidHandshake(format!( + "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" + ))); + } + + info!( + %local_addr, + %local_addr_nat, + reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), + %peer_addr, + %peer_addr_nat, + key_selector = format_args!("0x{ks:08x}"), + crypto_schema = format_args!("0x{schema:08x}"), + skew_secs = skew, + "ME key derivation parameters" + ); + + let ts_bytes = crypto_ts.to_le_bytes(); + let server_port_bytes = peer_addr_nat.port().to_le_bytes(); + let client_port_bytes = local_addr_nat.port().to_le_bytes(); + + let server_ip = extract_ip_material(peer_addr_nat); + let client_ip = extract_ip_material(local_addr_nat); + + let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = + match (server_ip, client_ip) { + (IpMaterial::V4(srv), IpMaterial::V4(clt)) => { + (Some(srv), Some(clt), None, None, clt, srv) + } + (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { + let zero = [0u8; 4]; + (None, None, Some(clt), Some(srv), zero, zero) + } + _ => { + return Err(ProxyError::InvalidHandshake( + "mixed IPv4/IPv6 endpoints are not supported for ME key derivation" + .to_string(), + )); + } + }; + + let diag_level: u8 = std::env::var("ME_DIAG") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(0); + + let prekey_client = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let prekey_server = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let (wk, wi) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let (rk, ri) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let hs_payload = + build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); + let hs_frame = build_rpc_frame(-1, &hs_payload); + if diag_level >= 1 { + info!( + write_key = %hex_dump(&wk), + write_iv = %hex_dump(&wi), + read_key = %hex_dump(&rk), + read_iv = %hex_dump(&ri), + srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + srv_port = %hex_dump(&server_port_bytes), + clt_port = %hex_dump(&client_port_bytes), + crypto_ts = %hex_dump(&ts_bytes), + nonce_srv = %hex_dump(&srv_nonce), + nonce_clt = %hex_dump(&my_nonce), + prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), + prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), + hs_plain = %hex_dump(&hs_frame), + proxy_secret_sha256 = %hex_dump(&sha256(secret)), + "ME diag: derived keys and handshake plaintext" + ); + } + if diag_level >= 2 { + info!( + prekey_client = %hex_dump(&prekey_client), + prekey_server = %hex_dump(&prekey_server), + "ME diag: full prekey buffers" + ); + } + + let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + if diag_level >= 1 { + info!( + hs_cipher = %hex_dump(&encrypted_hs), + "ME diag: handshake ciphertext" + ); + } + wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); + let mut enc_buf = BytesMut::with_capacity(256); + let mut dec_buf = BytesMut::with_capacity(256); + let mut read_iv = ri; + let mut handshake_ok = false; + + while Instant::now() < deadline && !handshake_ok { + let remaining = deadline - Instant::now(); + let mut tmp = [0u8; 256]; + let n = match timeout(remaining, rd.read(&mut tmp)).await { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ME closed during handshake", + ))); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => return Err(ProxyError::TgHandshakeTimeout), + }; + + enc_buf.extend_from_slice(&tmp[..n]); + + let blocks = enc_buf.len() / 16 * 16; + if blocks > 0 { + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&enc_buf[..blocks]); + read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; + dec_buf.extend_from_slice(&chunk); + let _ = enc_buf.split_to(blocks); + } + + while dec_buf.len() >= 4 { + let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; + + if fl == 4 { + let _ = dec_buf.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad HS response frame len: {fl}" + ))); + } + if dec_buf.len() < fl { + break; + } + + let frame = dec_buf.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + let ac = crate::crypto::crc32(&frame[..pe]); + if ec != ac { + return Err(ProxyError::InvalidHandshake(format!( + "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" + ))); + } + + let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + if hs_type == RPC_HANDSHAKE_ERROR_U32 { + let err_code = if frame.len() >= 16 { + i32::from_le_bytes(frame[12..16].try_into().unwrap()) + } else { + -1 + }; + return Err(ProxyError::InvalidHandshake(format!( + "ME rejected handshake (error={err_code})" + ))); + } + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + + handshake_ok = true; + break; + } + } + + if !handshake_ok { + return Err(ProxyError::TgHandshakeTimeout); + } + + info!(%addr, "RPC handshake OK"); + + let rpc_w = Arc::new(Mutex::new(RpcWriter { + writer: wr, + key: wk, + iv: write_iv, + seq_no: 0, + })); + self.writers.write().await.push((addr, rpc_w.clone())); + + let reg = self.registry.clone(); + let w_pong = rpc_w.clone(); + let w_pool = self.writers_arc(); + let w_ping = rpc_w.clone(); + let w_pool_ping = self.writers_arc(); + tokio::spawn(async move { + if let Err(e) = + reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await + { + warn!(error = %e, "ME reader ended"); + } + let mut ws = w_pool.write().await; + ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong)); + info!(remaining = ws.len(), "Dead ME writer removed from pool"); + }); + tokio::spawn(async move { + let mut ping_id: i64 = rand::random::(); + loop { + let jitter = rand::rng() + .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + tokio::time::sleep(Duration::from_secs(wait)).await; + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); + p.extend_from_slice(&ping_id.to_le_bytes()); + ping_id = ping_id.wrapping_add(1); + if let Err(e) = w_ping.lock().await.send(&p).await { + debug!(error = %e, "Active ME ping failed, removing dead writer"); + let mut ws = w_pool_ping.write().await; + ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping)); + break; + } + } + }); + + Ok(()) + } + +} + +fn hex_dump(data: &[u8]) -> String { + const MAX: usize = 64; + let mut out = String::with_capacity(data.len() * 2 + 3); + for (i, b) in data.iter().take(MAX).enumerate() { + if i > 0 { + out.push(' '); + } + out.push_str(&format!("{b:02x}")); + } + if data.len() > MAX { + out.push_str(" …"); + } + out +} diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs new file mode 100644 index 0000000..2a37ec4 --- /dev/null +++ b/src/transport/middle_proxy/pool_nat.rs @@ -0,0 +1,200 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use tracing::{info, warn}; + +use crate::error::{ProxyError, Result}; + +use super::MePool; + +impl MePool { + pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { + let nat_ip = self + .nat_ip_cfg + .or_else(|| self.nat_ip_detected.get().copied()); + + let Some(nat_ip) = nat_ip else { + return ip; + }; + + match (ip, nat_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) + if is_privateish(IpAddr::V4(src)) + || src.is_loopback() + || src.is_unspecified() => + { + IpAddr::V4(dst) + } + (IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => { + IpAddr::V6(dst) + } + (orig, _) => orig, + } + } + + pub(super) fn translate_our_addr_with_reflection( + &self, + addr: std::net::SocketAddr, + reflected: Option, + ) -> std::net::SocketAddr { + let ip = if let Some(r) = reflected { + // Use reflected IP (not port) only when local address is non-public. + if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() { + r.ip() + } else { + self.translate_ip_for_nat(addr.ip()) + } + } else { + self.translate_ip_for_nat(addr.ip()) + }; + + // Keep the kernel-assigned TCP source port; STUN port can differ. + std::net::SocketAddr::new(ip, addr.port()) + } + + pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option { + if self.nat_ip_cfg.is_some() { + return self.nat_ip_cfg; + } + + if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { + return None; + } + + if let Some(ip) = self.nat_ip_detected.get().copied() { + return Some(ip); + } + + match fetch_public_ipv4().await { + Ok(Some(ip)) => { + let _ = self.nat_ip_detected.set(IpAddr::V4(ip)); + info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); + Some(IpAddr::V4(ip)) + } + Ok(None) => None, + Err(e) => { + warn!(error = %e, "Failed to auto-detect public IP"); + None + } + } + } + + pub(super) async fn maybe_reflect_public_addr(&self) -> Option { + let stun_addr = self + .nat_stun + .clone() + .unwrap_or_else(|| "stun.l.google.com:19302".to_string()); + match fetch_stun_binding(&stun_addr).await { + Ok(sa) => { + if let Some(sa) = sa { + info!(%sa, "NAT probe: reflected address"); + } + sa + } + Err(e) => { + warn!(error = %e, "NAT probe failed"); + None + } + } + } +} + +async fn fetch_public_ipv4() -> Result> { + let res = reqwest::get("https://checkip.amazonaws.com").await.map_err(|e| { + ProxyError::Proxy(format!("public IP detection request failed: {e}")) + })?; + + let text = res.text().await.map_err(|e| { + ProxyError::Proxy(format!("public IP detection read failed: {e}")) + })?; + + let ip = text.trim().parse().ok(); + Ok(ip) +} + +async fn fetch_stun_binding(stun_addr: &str) -> Result> { + use rand::RngCore; + use tokio::net::UdpSocket; + + let socket = UdpSocket::bind("0.0.0.0:0") + .await + .map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?; + socket + .connect(stun_addr) + .await + .map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?; + + // Build minimal Binding Request. + let mut req = vec![0u8; 20]; + req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request + req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length + req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie + rand::thread_rng().fill_bytes(&mut req[8..20]); + + socket + .send(&req) + .await + .map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?; + + let mut buf = [0u8; 128]; + let n = socket + .recv(&mut buf) + .await + .map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?; + if n < 20 { + return Ok(None); + } + + // Parse attributes. + let mut idx = 20; + while idx + 4 <= n { + let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); + let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; + idx += 4; + if idx + alen > n { + break; + } + match atype { + 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { + if alen < 8 { + break; + } + let family = buf[idx + 1]; + if family != 0x01 { + // only IPv4 supported here + break; + } + let port_bytes = [buf[idx + 2], buf[idx + 3]]; + let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]]; + + let (port, ip) = if atype == 0x0020 { + let magic = 0x2112A442u32.to_be_bytes(); + let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16); + let ip = [ + ip_bytes[0] ^ magic[0], + ip_bytes[1] ^ magic[1], + ip_bytes[2] ^ magic[2], + ip_bytes[3] ^ magic[3], + ]; + (port, ip) + } else { + (u16::from_be_bytes(port_bytes), ip_bytes) + }; + return Ok(Some(std::net::SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])), + port, + ))); + } + _ => {} + } + idx += (alen + 3) & !3; // 4-byte alignment + } + + Ok(None) +} + +fn is_privateish(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(), + IpAddr::V6(v6) => v6.is_unique_local(), + } +} diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs new file mode 100644 index 0000000..83df742 --- /dev/null +++ b/src/transport/middle_proxy/reader.rs @@ -0,0 +1,141 @@ +use std::sync::Arc; + +use bytes::{Bytes, BytesMut}; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tracing::{debug, trace, warn}; + +use crate::crypto::{AesCbc, crc32}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +use super::codec::RpcWriter; +use super::{ConnRegistry, MeResponse}; + +pub(crate) async fn reader_loop( + mut rd: tokio::io::ReadHalf, + dk: [u8; 32], + mut div: [u8; 16], + reg: Arc, + enc_leftover: BytesMut, + mut dec: BytesMut, + writer: Arc>, +) -> Result<()> { + let mut raw = enc_leftover; + + loop { + let mut tmp = [0u8; 16_384]; + let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?; + if n == 0 { + return Ok(()); + } + raw.extend_from_slice(&tmp[..n]); + + let blocks = raw.len() / 16 * 16; + if blocks > 0 { + let mut new_iv = [0u8; 16]; + new_iv.copy_from_slice(&raw[blocks - 16..blocks]); + + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&raw[..blocks]); + AesCbc::new(dk, div) + .decrypt_in_place(&mut chunk) + .map_err(|e| ProxyError::Crypto(format!("{e}")))?; + div = new_iv; + dec.extend_from_slice(&chunk); + let _ = raw.split_to(blocks); + } + + while dec.len() >= 12 { + let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize; + if fl == 4 { + let _ = dec.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + warn!(frame_len = fl, "Invalid RPC frame len"); + dec.clear(); + break; + } + if dec.len() < fl { + break; + } + + let frame = dec.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + if crc32(&frame[..pe]) != ec { + warn!("CRC mismatch in data frame"); + continue; + } + + let payload = &frame[8..pe]; + if payload.len() < 4 { + continue; + } + + let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap()); + let body = &payload[4..]; + + if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 { + let flags = u32::from_le_bytes(body[0..4].try_into().unwrap()); + let cid = u64::from_le_bytes(body[4..12].try_into().unwrap()); + let data = Bytes::copy_from_slice(&body[12..]); + trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); + + let routed = reg.route(cid, MeResponse::Data { flags, data }).await; + if !routed { + reg.unregister(cid).await; + send_close_conn(&writer, cid).await; + } + } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap()); + trace!(cid, cfm, "RPC_SIMPLE_ACK"); + + let routed = reg.route(cid, MeResponse::Ack(cfm)).await; + if !routed { + reg.unregister(cid).await; + send_close_conn(&writer, cid).await; + } + } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + debug!(cid, "RPC_CLOSE_EXT from ME"); + reg.route(cid, MeResponse::Close).await; + reg.unregister(cid).await; + } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + debug!(cid, "RPC_CLOSE_CONN from ME"); + reg.route(cid, MeResponse::Close).await; + reg.unregister(cid).await; + } else if pt == RPC_PING_U32 && body.len() >= 8 { + let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); + trace!(ping_id, "RPC_PING -> RPC_PONG"); + let mut pong = Vec::with_capacity(12); + pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); + pong.extend_from_slice(&ping_id.to_le_bytes()); + if let Err(e) = writer.lock().await.send(&pong).await { + warn!(error = %e, "PONG send failed"); + break; + } + } else { + debug!( + rpc_type = format_args!("0x{pt:08x}"), + len = body.len(), + "Unknown RPC" + ); + } + } + } +} + +async fn send_close_conn(writer: &Arc>, conn_id: u64) { + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); + p.extend_from_slice(&conn_id.to_le_bytes()); + + if let Err(e) = writer.lock().await.send(&p).await { + debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN"); + } +} diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs new file mode 100644 index 0000000..1f13025 --- /dev/null +++ b/src/transport/middle_proxy/registry.rs @@ -0,0 +1,42 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; + +use tokio::sync::{RwLock, mpsc}; + +use super::MeResponse; + +pub struct ConnRegistry { + map: RwLock>>, + next_id: AtomicU64, +} + +impl ConnRegistry { + pub fn new() -> Self { + // Avoid fully predictable conn_id sequence from 1. + let start = rand::random::() | 1; + Self { + map: RwLock::new(HashMap::new()), + next_id: AtomicU64::new(start), + } + } + + pub async fn register(&self) -> (u64, mpsc::Receiver) { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = mpsc::channel(256); + self.map.write().await.insert(id, tx); + (id, rx) + } + + pub async fn unregister(&self, id: u64) { + self.map.write().await.remove(&id); + } + + pub async fn route(&self, id: u64, resp: MeResponse) -> bool { + let m = self.map.read().await; + if let Some(tx) = m.get(&id) { + tx.send(resp).await.is_ok() + } else { + false + } + } +} diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs new file mode 100644 index 0000000..b998411 --- /dev/null +++ b/src/transport/middle_proxy/secret.rs @@ -0,0 +1,81 @@ +use std::time::Duration; + +use tracing::{debug, info, warn}; + +use crate::error::{ProxyError, Result}; + +/// Fetch Telegram proxy-secret binary. +pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result> { + let cache = cache_path.unwrap_or("proxy-secret"); + + // 1) Try fresh download first. + match download_proxy_secret().await { + Ok(data) => { + if let Err(e) = tokio::fs::write(cache, &data).await { + warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); + } else { + debug!(path = cache, len = data.len(), "Cached proxy-secret"); + } + return Ok(data); + } + Err(download_err) => { + warn!(error = %download_err, "Proxy-secret download failed, trying cache/file fallback"); + // Fall through to cache/file. + } + } + + // 2) Fallback to cache/file regardless of age; require len>=32. + match tokio::fs::read(cache).await { + Ok(data) if data.len() >= 32 => { + let age_hours = tokio::fs::metadata(cache) + .await + .ok() + .and_then(|m| m.modified().ok()) + .and_then(|m| std::time::SystemTime::now().duration_since(m).ok()) + .map(|d| d.as_secs() / 3600); + info!( + path = cache, + len = data.len(), + age_hours, + "Loaded proxy-secret from cache/file after download failure" + ); + Ok(data) + } + Ok(data) => Err(ProxyError::Proxy(format!( + "Cached proxy-secret too short: {} bytes (need >= 32)", + data.len() + ))), + Err(e) => Err(ProxyError::Proxy(format!( + "Failed to read proxy-secret cache after download failure: {e}" + ))), + } +} + +async fn download_proxy_secret() -> Result> { + let resp = reqwest::get("https://core.telegram.org/getProxySecret") + .await + .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + + if !resp.status().is_success() { + return Err(ProxyError::Proxy(format!( + "proxy-secret download HTTP {}", + resp.status() + ))); + } + + let data = resp + .bytes() + .await + .map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))? + .to_vec(); + + if data.len() < 32 { + return Err(ProxyError::Proxy(format!( + "proxy-secret too short: {} bytes (need >= 32)", + data.len() + ))); + } + + info!(len = data.len(), "Downloaded proxy-secret OK"); + Ok(data) +} diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs new file mode 100644 index 0000000..192a560 --- /dev/null +++ b/src/transport/middle_proxy/send.rs @@ -0,0 +1,146 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use tokio::sync::Mutex; +use tracing::{debug, warn}; + +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::{RPC_CLOSE_EXT_U32, TG_MIDDLE_PROXIES_V4}; + +use super::MePool; +use super::codec::RpcWriter; +use super::wire::build_proxy_req_payload; + +impl MePool { + pub async fn send_proxy_req( + &self, + conn_id: u64, + target_dc: i16, + client_addr: SocketAddr, + our_addr: SocketAddr, + data: &[u8], + proto_flags: u32, + ) -> Result<()> { + let payload = build_proxy_req_payload( + conn_id, + client_addr, + our_addr, + data, + self.proxy_tag.as_deref(), + proto_flags, + ); + + loop { + let ws = self.writers.read().await; + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + let writers: Vec<(SocketAddr, Arc>)> = ws.iter().cloned().collect(); + drop(ws); + + let candidate_indices = candidate_indices_for_dc(&writers, target_dc); + if candidate_indices.is_empty() { + return Err(ProxyError::Proxy("No ME writers available for target DC".into())); + } + let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); + + // Prefer immediately available writer to avoid waiting on stalled connection. + for offset in 0..candidate_indices.len() { + let cidx = (start + offset) % candidate_indices.len(); + let idx = candidate_indices[cidx]; + let w = writers[idx].1.clone(); + if let Ok(mut guard) = w.try_lock() { + let send_res = guard.send(&payload).await; + drop(guard); + match send_res { + Ok(()) => return Ok(()), + Err(e) => { + warn!(error = %e, "ME write failed, removing dead conn"); + let mut ws = self.writers.write().await; + ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + continue; + } + } + } + } + + // All writers are currently busy, wait for the selected one. + let w = writers[candidate_indices[start]].1.clone(); + match w.lock().await.send(&payload).await { + Ok(()) => return Ok(()), + Err(e) => { + warn!(error = %e, "ME write failed, removing dead conn"); + let mut ws = self.writers.write().await; + ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + } + } + } + } + + pub async fn send_close(&self, conn_id: u64) -> Result<()> { + let ws = self.writers.read().await; + if !ws.is_empty() { + let w = ws[0].1.clone(); + drop(ws); + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); + p.extend_from_slice(&conn_id.to_le_bytes()); + if let Err(e) = w.lock().await.send(&p).await { + debug!(error = %e, "ME close write failed"); + let mut ws = self.writers.write().await; + ws.retain(|(_, o)| !Arc::ptr_eq(o, &w)); + } + } + + self.registry.unregister(conn_id).await; + Ok(()) + } + + pub fn connection_count(&self) -> usize { + self.writers.try_read().map(|w| w.len()).unwrap_or(0) + } +} + +fn candidate_indices_for_dc( + writers: &[(SocketAddr, Arc>)], + target_dc: i16, +) -> Vec { + let mut preferred = Vec::::new(); + let key = target_dc as i32; + if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&key) { + preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); + } + if preferred.is_empty() { + let abs = key.abs(); + if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&abs) { + preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); + } + } + if preferred.is_empty() { + let abs = key.abs(); + if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&-abs) { + preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); + } + } + if preferred.is_empty() { + return (0..writers.len()).collect(); + } + + let mut out = Vec::new(); + for (idx, (addr, _)) in writers.iter().enumerate() { + if preferred.iter().any(|p| p == addr) { + out.push(idx); + } + } + if out.is_empty() { + return (0..writers.len()).collect(); + } + out +} diff --git a/src/transport/middle_proxy/wire.rs b/src/transport/middle_proxy/wire.rs new file mode 100644 index 0000000..1ed9727 --- /dev/null +++ b/src/transport/middle_proxy/wire.rs @@ -0,0 +1,106 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use crate::protocol::constants::*; + +#[derive(Clone, Copy)] +pub(crate) enum IpMaterial { + V4([u8; 4]), + V6([u8; 16]), +} + +pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial { + match addr.ip() { + IpAddr::V4(v4) => IpMaterial::V4(v4.octets()), + IpAddr::V6(v6) => { + if let Some(v4) = v6.to_ipv4_mapped() { + IpMaterial::V4(v4.octets()) + } else { + IpMaterial::V6(v6.octets()) + } + } + } +} + +fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] { + let mut buf = [0u8; 16]; + + // Matches tl_store_long(0) + tl_store_int(-0x10000). + buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes()); + + // Matches tl_store_int(htonl(remote_ip_host_order)). + let host_order = u32::from_ne_bytes(ip.octets()); + let network_order = host_order.to_be(); + buf[12..16].copy_from_slice(&network_order.to_le_bytes()); + + buf +} + +fn append_mapped_addr_and_port(buf: &mut Vec, addr: SocketAddr) { + match addr.ip() { + IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)), + IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()), + } + buf.extend_from_slice(&(addr.port() as u32).to_le_bytes()); +} + +pub(crate) fn build_proxy_req_payload( + conn_id: u64, + client_addr: SocketAddr, + our_addr: SocketAddr, + data: &[u8], + proxy_tag: Option<&[u8]>, + proto_flags: u32, +) -> Vec { + let mut b = Vec::with_capacity(128 + data.len()); + + b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); + b.extend_from_slice(&proto_flags.to_le_bytes()); + b.extend_from_slice(&conn_id.to_le_bytes()); + + append_mapped_addr_and_port(&mut b, client_addr); + append_mapped_addr_and_port(&mut b, our_addr); + + if proto_flags & 12 != 0 { + let extra_start = b.len(); + b.extend_from_slice(&0u32.to_le_bytes()); + + if let Some(tag) = proxy_tag { + b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes()); + + if tag.len() < 254 { + b.push(tag.len() as u8); + b.extend_from_slice(tag); + let pad = (4 - ((1 + tag.len()) % 4)) % 4; + b.extend(std::iter::repeat_n(0u8, pad)); + } else { + b.push(0xfe); + let len_bytes = (tag.len() as u32).to_le_bytes(); + b.extend_from_slice(&len_bytes[..3]); + b.extend_from_slice(tag); + let pad = (4 - (tag.len() % 4)) % 4; + b.extend(std::iter::repeat_n(0u8, pad)); + } + } + + let extra_bytes = (b.len() - extra_start - 4) as u32; + b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes()); + } + + b.extend_from_slice(data); + b +} + +pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 { + use crate::protocol::constants::ProtoTag; + + let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2; + if has_proxy_tag { + flags |= RPC_FLAG_HAS_AD_TAG; + } + + match tag { + ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED, + ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE, + ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE, + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 2b507d5..51cffa4 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -10,4 +10,5 @@ pub use pool::ConnectionPool; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use socket::*; pub use socks::*; -pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; \ No newline at end of file +pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager}; +pub mod middle_proxy; diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 6adf452..4b5fe9c 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -1,4 +1,6 @@ //! Upstream Management with per-DC latency-weighted selection +//! +//! IPv6/IPv4 connectivity checks with configurable preference. use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; @@ -18,6 +20,9 @@ use crate::transport::socks::{connect_socks4, connect_socks5}; /// Number of Telegram datacenters const NUM_DCS: usize = 5; +/// Timeout for individual DC ping attempt +const DC_PING_TIMEOUT_SECS: u64 = 5; + // ============= RTT Tracking ============= #[derive(Debug, Clone, Copy)] @@ -30,19 +35,42 @@ impl LatencyEma { const fn new(alpha: f64) -> Self { Self { value_ms: None, alpha } } - + fn update(&mut self, sample_ms: f64) { self.value_ms = Some(match self.value_ms { None => sample_ms, Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha, }); } - + fn get(&self) -> Option { self.value_ms } } +// ============= Per-DC IP Preference Tracking ============= + +/// Tracks which IP version works for each DC +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum IpPreference { + /// Not yet tested + Unknown, + /// IPv6 works + PreferV6, + /// Only IPv4 works (IPv6 failed) + PreferV4, + /// Both work + BothWork, + /// Both failed + Unavailable, +} + +impl Default for IpPreference { + fn default() -> Self { + Self::Unknown + } +} + // ============= Upstream State ============= #[derive(Debug)] @@ -53,6 +81,8 @@ struct UpstreamState { last_check: std::time::Instant, /// Per-DC latency EMA (index 0 = DC1, index 4 = DC5) dc_latency: [LatencyEma; NUM_DCS], + /// Per-DC IP version preference (learned from connectivity tests) + dc_ip_pref: [IpPreference; NUM_DCS], } impl UpstreamState { @@ -63,16 +93,11 @@ impl UpstreamState { fails: 0, last_check: std::time::Instant::now(), dc_latency: [LatencyEma::new(0.3); NUM_DCS], + dc_ip_pref: [IpPreference::Unknown; NUM_DCS], } } - + /// Map DC index to latency array slot (0..NUM_DCS). - /// - /// Matches the C implementation's `mf_cluster_lookup` behavior: - /// - Standard DCs ±1..±5 → direct mapping to array index 0..4 - /// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2) - /// This matches Telegram's `default 2;` in proxy-multi.conf. - /// - There is NO modular arithmetic in the C implementation. fn dc_array_idx(dc_idx: i16) -> Option { let abs_dc = dc_idx.unsigned_abs() as usize; if abs_dc == 0 { @@ -82,25 +107,22 @@ impl UpstreamState { Some(abs_dc - 1) } else { // Unknown DC → default cluster (DC 2, index 1) - // Same as C: mf_cluster_lookup returns default_cluster Some(1) } } - + /// Get latency for a specific DC, falling back to average across all known DCs fn effective_latency(&self, dc_idx: Option) -> Option { - // Try DC-specific latency first if let Some(di) = dc_idx.and_then(Self::dc_array_idx) { if let Some(ms) = self.dc_latency[di].get() { return Some(ms); } } - - // Fallback: average of all known DC latencies + let (sum, count) = self.dc_latency.iter() .filter_map(|l| l.get()) .fold((0.0, 0u32), |(s, c), v| (s + v, c + 1)); - + if count > 0 { Some(sum / count as f64) } else { None } } } @@ -114,11 +136,14 @@ pub struct DcPingResult { pub error: Option, } -/// Result of startup ping for one upstream +/// Result of startup ping for one upstream (separate v6/v4 results) #[derive(Debug, Clone)] pub struct StartupPingResult { - pub results: Vec, + pub v6_results: Vec, + pub v4_results: Vec, pub upstream_name: String, + /// True if both IPv6 and IPv4 have at least one working DC + pub both_available: bool, } // ============= Upstream Manager ============= @@ -134,22 +159,13 @@ impl UpstreamManager { .filter(|c| c.enabled) .map(UpstreamState::new) .collect(); - + Self { upstreams: Arc::new(RwLock::new(states)), } } - + /// Select upstream using latency-weighted random selection. - /// - /// `effective_weight = config_weight × latency_factor` - /// - /// where `latency_factor = 1000 / latency_ms` if latency is known, - /// or `1.0` if no latency data is available. - /// - /// This means a 50ms upstream gets factor 20, a 200ms upstream gets - /// factor 5 — the faster route is 4× more likely to be chosen - /// (all else being equal). async fn select_upstream(&self, dc_idx: Option) -> Option { let upstreams = self.upstreams.read().await; if upstreams.is_empty() { @@ -161,34 +177,32 @@ impl UpstreamManager { .filter(|(_, u)| u.healthy) .map(|(i, _)| i) .collect(); - + if healthy.is_empty() { - // All unhealthy — pick any return Some(rand::rng().gen_range(0..upstreams.len())); } - + if healthy.len() == 1 { return Some(healthy[0]); } - - // Calculate latency-weighted scores + let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| { let base = upstreams[i].config.weight as f64; let latency_factor = upstreams[i].effective_latency(dc_idx) .map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 }) .unwrap_or(1.0); - + (i, base * latency_factor) }).collect(); - + let total: f64 = weights.iter().map(|(_, w)| w).sum(); - + if total <= 0.0 { return Some(healthy[rand::rng().gen_range(0..healthy.len())]); } - + let mut choice: f64 = rand::rng().gen_range(0.0..total); - + for &(idx, weight) in &weights { if choice < weight { trace!( @@ -202,25 +216,22 @@ impl UpstreamManager { } choice -= weight; } - + Some(healthy[0]) } - + /// Connect to target through a selected upstream. - /// - /// `dc_idx` is used for latency-based upstream selection and RTT tracking. - /// Pass `None` if DC index is unknown. pub async fn connect(&self, target: SocketAddr, dc_idx: Option) -> Result { let idx = self.select_upstream(dc_idx).await .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; - + let upstream = { let guard = self.upstreams.read().await; guard[idx].config.clone() }; - + let start = Instant::now(); - + match self.connect_via_upstream(&upstream, target).await { Ok(stream) => { let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; @@ -231,8 +242,7 @@ impl UpstreamManager { } u.healthy = true; u.fails = 0; - - // Store per-DC latency + if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) { u.dc_latency[di].update(rtt_ms); } @@ -253,92 +263,93 @@ impl UpstreamManager { } } } - + async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result { match &config.upstream_type { UpstreamType::Direct { interface } => { let bind_ip = interface.as_ref() .and_then(|s| s.parse::().ok()); - + let socket = create_outgoing_socket_bound(target, bind_ip)?; - + socket.set_nonblocking(true)?; match socket.connect(&target.into()) { Ok(()) => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } - + let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - + stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } - + Ok(stream) }, UpstreamType::Socks4 { address, interface, user_id } => { let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; - + let bind_ip = interface.as_ref() .and_then(|s| s.parse::().ok()); - + let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - + socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } - + let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - + stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } - + connect_socks4(&mut stream, target, user_id.as_deref()).await?; Ok(stream) }, UpstreamType::Socks5 { address, interface, username, password } => { let proxy_addr: SocketAddr = address.parse() .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; - + let bind_ip = interface.as_ref() .and_then(|s| s.parse::().ok()); - + let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; - + socket.set_nonblocking(true)?; match socket.connect(&proxy_addr.into()) { Ok(()) => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) => return Err(ProxyError::Io(err)), } - + let std_stream: std::net::TcpStream = socket.into(); let mut stream = TcpStream::from_std(std_stream)?; - + stream.writable().await?; if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } - + connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?; Ok(stream) }, } } - - // ============= Startup Ping ============= - + + // ============= Startup Ping (test both IPv6 and IPv4) ============= + /// Ping all Telegram DCs through all upstreams. + /// Tests BOTH IPv6 and IPv4, returns separate results for each. pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec { let upstreams: Vec<(usize, UpstreamConfig)> = { let guard = self.upstreams.read().await; @@ -346,11 +357,9 @@ impl UpstreamManager { .map(|(i, u)| (i, u.config.clone())) .collect() }; - - let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; - + let mut all_results = Vec::new(); - + for (upstream_idx, upstream_config) in &upstreams { let upstream_name = match &upstream_config.upstream_type { UpstreamType::Direct { interface } => { @@ -359,130 +368,260 @@ impl UpstreamManager { UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address), UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address), }; - - let mut dc_results = Vec::new(); - - for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() { - let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT); - - let ping_result = tokio::time::timeout( - Duration::from_secs(5), - self.ping_single_dc(upstream_config, dc_addr) + + let mut v6_results = Vec::new(); + let mut v4_results = Vec::new(); + + // === Ping IPv6 first === + for dc_zero_idx in 0..NUM_DCS { + let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx]; + let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT); + + let result = tokio::time::timeout( + Duration::from_secs(DC_PING_TIMEOUT_SECS), + self.ping_single_dc(&upstream_config, addr_v6) ).await; - - let result = match ping_result { + + let ping_result = match result { Ok(Ok(rtt_ms)) => { - // Store per-DC latency let mut guard = self.upstreams.write().await; if let Some(u) = guard.get_mut(*upstream_idx) { u.dc_latency[dc_zero_idx].update(rtt_ms); } DcPingResult { dc_idx: dc_zero_idx + 1, - dc_addr, + dc_addr: addr_v6, rtt_ms: Some(rtt_ms), error: None, } } Ok(Err(e)) => DcPingResult { dc_idx: dc_zero_idx + 1, - dc_addr, + dc_addr: addr_v6, rtt_ms: None, error: Some(e.to_string()), }, Err(_) => DcPingResult { dc_idx: dc_zero_idx + 1, - dc_addr, + dc_addr: addr_v6, rtt_ms: None, - error: Some("timeout (5s)".to_string()), + error: Some("timeout".to_string()), }, }; - - dc_results.push(result); + v6_results.push(ping_result); } - + + // === Then ping IPv4 === + for dc_zero_idx in 0..NUM_DCS { + let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx]; + let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT); + + let result = tokio::time::timeout( + Duration::from_secs(DC_PING_TIMEOUT_SECS), + self.ping_single_dc(&upstream_config, addr_v4) + ).await; + + let ping_result = match result { + Ok(Ok(rtt_ms)) => { + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(*upstream_idx) { + u.dc_latency[dc_zero_idx].update(rtt_ms); + } + DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr: addr_v4, + rtt_ms: Some(rtt_ms), + error: None, + } + } + Ok(Err(e)) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr: addr_v4, + rtt_ms: None, + error: Some(e.to_string()), + }, + Err(_) => DcPingResult { + dc_idx: dc_zero_idx + 1, + dc_addr: addr_v4, + rtt_ms: None, + error: Some("timeout".to_string()), + }, + }; + v4_results.push(ping_result); + } + + // Check if both IP versions have at least one working DC + let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some()); + let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some()); + let both_available = v6_has_working && v4_has_working; + + // Update IP preference for each DC + { + let mut guard = self.upstreams.write().await; + if let Some(u) = guard.get_mut(*upstream_idx) { + for dc_zero_idx in 0..NUM_DCS { + let v6_ok = v6_results[dc_zero_idx].rtt_ms.is_some(); + let v4_ok = v4_results[dc_zero_idx].rtt_ms.is_some(); + + u.dc_ip_pref[dc_zero_idx] = match (v6_ok, v4_ok) { + (true, true) => IpPreference::BothWork, + (true, false) => IpPreference::PreferV6, + (false, true) => IpPreference::PreferV4, + (false, false) => IpPreference::Unavailable, + }; + } + } + } + all_results.push(StartupPingResult { - results: dc_results, + v6_results, + v4_results, upstream_name, + both_available, }); } - + all_results } - + async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result { let start = Instant::now(); let _stream = self.connect_via_upstream(config, target).await?; Ok(start.elapsed().as_secs_f64() * 1000.0) } - + // ============= Health Checks ============= - + /// Background health check: rotates through DCs, 30s interval. + /// Uses preferred IP version based on config. pub async fn run_health_checks(&self, prefer_ipv6: bool) { - let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 }; let mut dc_rotation = 0usize; - + loop { tokio::time::sleep(Duration::from_secs(30)).await; - - let dc_zero_idx = dc_rotation % datacenters.len(); + + let dc_zero_idx = dc_rotation % NUM_DCS; dc_rotation += 1; - - let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT); - + + let dc_addr = if prefer_ipv6 { + SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT) + } else { + SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT) + }; + + let fallback_addr = if prefer_ipv6 { + SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT) + } else { + SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT) + }; + let count = self.upstreams.read().await.len(); + for i in 0..count { let config = { let guard = self.upstreams.read().await; guard[i].config.clone() }; - + let start = Instant::now(); let result = tokio::time::timeout( Duration::from_secs(10), - self.connect_via_upstream(&config, check_target) + self.connect_via_upstream(&config, dc_addr) ).await; - - let mut guard = self.upstreams.write().await; - let u = &mut guard[i]; - + match result { Ok(Ok(_stream)) => { let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; + let mut guard = self.upstreams.write().await; + let u = &mut guard[i]; u.dc_latency[dc_zero_idx].update(rtt_ms); - + if !u.healthy { info!( - rtt = format!("{:.0}ms", rtt_ms), + rtt = format!("{:.0} ms", rtt_ms), dc = dc_zero_idx + 1, "Upstream recovered" ); } u.healthy = true; u.fails = 0; + u.last_check = std::time::Instant::now(); } - Ok(Err(e)) => { - u.fails += 1; - debug!(dc = dc_zero_idx + 1, fails = u.fails, - "Health check failed: {}", e); - if u.fails > 3 { - u.healthy = false; - warn!("Upstream unhealthy (fails)"); - } - } - Err(_) => { - u.fails += 1; - debug!(dc = dc_zero_idx + 1, fails = u.fails, - "Health check timeout"); - if u.fails > 3 { - u.healthy = false; - warn!("Upstream unhealthy (timeout)"); + Ok(Err(_)) | Err(_) => { + // Try fallback + debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback"); + + let start2 = Instant::now(); + let result2 = tokio::time::timeout( + Duration::from_secs(10), + self.connect_via_upstream(&config, fallback_addr) + ).await; + + let mut guard = self.upstreams.write().await; + let u = &mut guard[i]; + + match result2 { + Ok(Ok(_stream)) => { + let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0; + u.dc_latency[dc_zero_idx].update(rtt_ms); + + if !u.healthy { + info!( + rtt = format!("{:.0} ms", rtt_ms), + dc = dc_zero_idx + 1, + "Upstream recovered (fallback)" + ); + } + u.healthy = true; + u.fails = 0; + } + Ok(Err(e)) => { + u.fails += 1; + debug!(dc = dc_zero_idx + 1, fails = u.fails, + "Health check failed (both): {}", e); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (fails)"); + } + } + Err(_) => { + u.fails += 1; + debug!(dc = dc_zero_idx + 1, fails = u.fails, + "Health check timeout (both)"); + if u.fails > 3 { + u.healthy = false; + warn!("Upstream unhealthy (timeout)"); + } + } } + u.last_check = std::time::Instant::now(); } } - u.last_check = std::time::Instant::now(); } } } + + /// Get the preferred IP for a DC (for use by other components) + pub async fn get_dc_ip_preference(&self, dc_idx: i16) -> Option { + let guard = self.upstreams.read().await; + if guard.is_empty() { + return None; + } + + UpstreamState::dc_array_idx(dc_idx) + .map(|idx| guard[0].dc_ip_pref[idx]) + } + + /// Get preferred DC address based on config preference + pub async fn get_dc_addr(&self, dc_idx: i16, prefer_ipv6: bool) -> Option { + let arr_idx = UpstreamState::dc_array_idx(dc_idx)?; + + let ip = if prefer_ipv6 { + TG_DATACENTERS_V6[arr_idx] + } else { + TG_DATACENTERS_V4[arr_idx] + }; + + Some(SocketAddr::new(ip, TG_DATACENTER_PORT)) + } } \ No newline at end of file diff --git a/telemt b/telemt new file mode 100644 index 0000000..9db056f Binary files /dev/null and b/telemt differ diff --git a/tools/dc.py b/tools/dc.py new file mode 100644 index 0000000..f142baf --- /dev/null +++ b/tools/dc.py @@ -0,0 +1,121 @@ +from telethon import TelegramClient +from telethon.tl.functions.help import GetConfigRequest +import asyncio + +api_id = '' +api_hash = '' + +async def get_all_servers(): + print("🔄 Подключаемся к Telegram...") + client = TelegramClient('session', api_id, api_hash) + + await client.start() + print("✅ Подключение установлено!\n") + + print("📡 Запрашиваем конфигурацию серверов...") + config = await client(GetConfigRequest()) + + print(f"📊 Получено серверов: {len(config.dc_options)}\n") + print("="*80) + + # Группируем серверы по DC ID + dc_groups = {} + for dc in config.dc_options: + if dc.id not in dc_groups: + dc_groups[dc.id] = [] + dc_groups[dc.id].append(dc) + + # Выводим все серверы, сгруппированные по DC + for dc_id in sorted(dc_groups.keys()): + servers = dc_groups[dc_id] + print(f"\n🌐 DATACENTER {dc_id} ({len(servers)} серверов)") + print("-" * 80) + + for dc in servers: + # Собираем флаги + flags = [] + if dc.ipv6: + flags.append("IPv6") + if dc.media_only: + flags.append("🎬 MEDIA-ONLY") + if dc.cdn: + flags.append("📦 CDN") + if dc.tcpo_only: + flags.append("🔒 TCPO") + if dc.static: + flags.append("📌 STATIC") + + flags_str = f" [{', '.join(flags)}]" if flags else " [STANDARD]" + + # Форматируем IP (выравниваем для читаемости) + ip_display = f"{dc.ip_address:45}" + + print(f" {ip_display}:{dc.port:5}{flags_str}") + + # Статистика + print("\n" + "="*80) + print("📈 СТАТИСТИКА:") + print("="*80) + + total = len(config.dc_options) + ipv4_count = sum(1 for dc in config.dc_options if not dc.ipv6) + ipv6_count = sum(1 for dc in config.dc_options if dc.ipv6) + media_count = sum(1 for dc in config.dc_options if dc.media_only) + cdn_count = sum(1 for dc in config.dc_options if dc.cdn) + tcpo_count = sum(1 for dc in config.dc_options if dc.tcpo_only) + static_count = sum(1 for dc in config.dc_options if dc.static) + + print(f" Всего серверов: {total}") + print(f" IPv4 серверы: {ipv4_count}") + print(f" IPv6 серверы: {ipv6_count}") + print(f" Media-only: {media_count}") + print(f" CDN серверы: {cdn_count}") + print(f" TCPO-only: {tcpo_count}") + print(f" Static: {static_count}") + + # Дополнительная информация из config + print("\n" + "="*80) + print("ℹ️ ДОПОЛНИТЕЛЬНАЯ ИНФОРМАЦИЯ:") + print("="*80) + print(f" Дата конфигурации: {config.date}") + print(f" Expires: {config.expires}") + print(f" Test mode: {config.test_mode}") + print(f" This DC: {config.this_dc}") + + # Сохраняем в файл + print("\n💾 Сохраняем результаты в файл telegram_servers.txt...") + with open('telegram_servers.txt', 'w', encoding='utf-8') as f: + f.write("TELEGRAM DATACENTER SERVERS\n") + f.write("="*80 + "\n\n") + + for dc_id in sorted(dc_groups.keys()): + servers = dc_groups[dc_id] + f.write(f"\nDATACENTER {dc_id} ({len(servers)} servers)\n") + f.write("-" * 80 + "\n") + + for dc in servers: + flags = [] + if dc.ipv6: + flags.append("IPv6") + if dc.media_only: + flags.append("MEDIA-ONLY") + if dc.cdn: + flags.append("CDN") + if dc.tcpo_only: + flags.append("TCPO") + if dc.static: + flags.append("STATIC") + + flags_str = f" [{', '.join(flags)}]" if flags else " [STANDARD]" + f.write(f" {dc.ip_address}:{dc.port}{flags_str}\n") + + f.write(f"\n\nTotal servers: {total}\n") + f.write(f"Generated: {config.date}\n") + + print("✅ Результаты сохранены в telegram_servers.txt") + + await client.disconnect() + print("\n👋 Отключились от Telegram") + +if __name__ == '__main__': + asyncio.run(get_all_servers()) \ No newline at end of file