use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::AtomicU64; use std::time::Duration; use bytes::BytesMut; use rand::Rng; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::sync::{Mutex, RwLock}; use tokio::time::{Instant, timeout}; use tracing::{debug, info, warn}; use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use super::ConnRegistry; use super::codec::{ RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, }; use super::reader::reader_loop; use super::wire::{IpMaterial, extract_ip_material}; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; pub struct MePool { pub(super) registry: Arc, pub(super) writers: Arc>)>>> , pub(super) rr: AtomicU64, pub(super) proxy_tag: Option>, proxy_secret: Vec, pub(super) nat_ip_cfg: Option, pub(super) nat_ip_detected: OnceLock, pub(super) nat_probe: bool, pub(super) nat_stun: Option, pool_size: usize, } impl MePool { pub fn new( proxy_tag: Option>, proxy_secret: Vec, nat_ip: Option, nat_probe: bool, nat_stun: Option, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), writers: Arc::new(RwLock::new(Vec::new())), rr: AtomicU64::new(0), proxy_tag, proxy_secret, nat_ip_cfg: nat_ip, nat_ip_detected: OnceLock::new(), nat_probe, nat_stun, pool_size: 2, }) } pub fn has_proxy_tag(&self) -> bool { self.proxy_tag.is_some() } pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { let ip = self.translate_ip_for_nat(addr.ip()); SocketAddr::new(ip, addr.port()) } pub fn registry(&self) -> &Arc { &self.registry } fn writers_arc(&self) -> Arc>)>>> { self.writers.clone() } fn key_selector(&self) -> u32 { if self.proxy_secret.len() >= 4 { u32::from_le_bytes([ self.proxy_secret[0], self.proxy_secret[1], self.proxy_secret[2], self.proxy_secret[3], ]) } else { 0 } } pub async fn init(self: &Arc, pool_size: usize, rng: &SecureRandom) -> Result<()> { let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4; let ks = self.key_selector(); info!( me_servers = addrs.len(), pool_size, key_selector = format_args!("0x{ks:08x}"), secret_len = self.proxy_secret.len(), "Initializing ME pool" ); for &(ip, port) in addrs.iter() { for i in 0..pool_size { let addr = SocketAddr::new(ip, port); match self.connect_one(addr, rng).await { Ok(()) => info!(%addr, idx = i, "ME connected"), Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"), } } if self.writers.read().await.len() >= pool_size { break; } } if self.writers.read().await.is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } Ok(()) } pub(crate) async fn connect_one( self: &Arc, addr: SocketAddr, rng: &SecureRandom, ) -> Result<()> { let secret = &self.proxy_secret; if secret.len() < 32 { return Err(ProxyError::Proxy( "proxy-secret too short for ME auth".into(), )); } let stream = timeout( Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), TcpStream::connect(addr), ) .await .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string(), })? .map_err(ProxyError::Io)?; stream.set_nodelay(true).ok(); let local_addr = stream.local_addr().map_err(ProxyError::Io)?; let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; let reflected = if self.nat_probe { self.maybe_reflect_public_addr().await } else { None }; let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); let (mut rd, mut wr) = tokio::io::split(stream); let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); let crypto_ts = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as u32; let ks = self.key_selector(); let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); let nonce_frame = build_rpc_frame(-2, &nonce_payload); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); info!( key_selector = format_args!("0x{ks:08x}"), crypto_ts, frame_len = nonce_frame.len(), nonce_frame_hex = %dump, "Sending ME nonce frame" ); wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; let (srv_seq, srv_nonce_payload) = timeout( Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), read_rpc_frame_plaintext(&mut rd), ) .await .map_err(|_| ProxyError::TgHandshakeTimeout)??; if srv_seq != -2 { return Err(ProxyError::InvalidHandshake(format!( "Expected seq=-2, got {srv_seq}" ))); } let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; if schema != RPC_CRYPTO_AES_U32 { warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); return Err(ProxyError::InvalidHandshake(format!( "Unsupported crypto schema: 0x{schema:x}" ))); } if srv_key_select != ks { return Err(ProxyError::InvalidHandshake(format!( "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" ))); } let skew = crypto_ts.abs_diff(srv_ts); if skew > 30 { return Err(ProxyError::InvalidHandshake(format!( "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" ))); } info!( %local_addr, %local_addr_nat, reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), %peer_addr, %peer_addr_nat, key_selector = format_args!("0x{ks:08x}"), crypto_schema = format_args!("0x{schema:08x}"), skew_secs = skew, "ME key derivation parameters" ); let ts_bytes = crypto_ts.to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes(); let client_port_bytes = local_addr_nat.port().to_le_bytes(); let server_ip = extract_ip_material(peer_addr_nat); let client_ip = extract_ip_material(local_addr_nat); let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { (IpMaterial::V4(srv), IpMaterial::V4(clt)) => { (Some(srv), Some(clt), None, None, clt, srv) } (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { let zero = [0u8; 4]; (None, None, Some(clt), Some(srv), zero, zero) } _ => { return Err(ProxyError::InvalidHandshake( "mixed IPv4/IPv6 endpoints are not supported for ME key derivation" .to_string(), )); } }; let diag_level: u8 = std::env::var("ME_DIAG") .ok() .and_then(|v| v.parse().ok()) .unwrap_or(0); let prekey_client = build_middleproxy_prekey( &srv_nonce, &my_nonce, &ts_bytes, srv_ip_opt.as_ref().map(|x| &x[..]), &client_port_bytes, b"CLIENT", clt_ip_opt.as_ref().map(|x| &x[..]), &server_port_bytes, secret, clt_v6_opt.as_ref(), srv_v6_opt.as_ref(), ); let prekey_server = build_middleproxy_prekey( &srv_nonce, &my_nonce, &ts_bytes, srv_ip_opt.as_ref().map(|x| &x[..]), &client_port_bytes, b"SERVER", clt_ip_opt.as_ref().map(|x| &x[..]), &server_port_bytes, secret, clt_v6_opt.as_ref(), srv_v6_opt.as_ref(), ); let (wk, wi) = derive_middleproxy_keys( &srv_nonce, &my_nonce, &ts_bytes, srv_ip_opt.as_ref().map(|x| &x[..]), &client_port_bytes, b"CLIENT", clt_ip_opt.as_ref().map(|x| &x[..]), &server_port_bytes, secret, clt_v6_opt.as_ref(), srv_v6_opt.as_ref(), ); let (rk, ri) = derive_middleproxy_keys( &srv_nonce, &my_nonce, &ts_bytes, srv_ip_opt.as_ref().map(|x| &x[..]), &client_port_bytes, b"SERVER", clt_ip_opt.as_ref().map(|x| &x[..]), &server_port_bytes, secret, clt_v6_opt.as_ref(), srv_v6_opt.as_ref(), ); let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); let hs_frame = build_rpc_frame(-1, &hs_payload); if diag_level >= 1 { info!( write_key = %hex_dump(&wk), write_iv = %hex_dump(&wi), read_key = %hex_dump(&rk), read_iv = %hex_dump(&ri), srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), srv_port = %hex_dump(&server_port_bytes), clt_port = %hex_dump(&client_port_bytes), crypto_ts = %hex_dump(&ts_bytes), nonce_srv = %hex_dump(&srv_nonce), nonce_clt = %hex_dump(&my_nonce), prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), hs_plain = %hex_dump(&hs_frame), proxy_secret_sha256 = %hex_dump(&sha256(secret)), "ME diag: derived keys and handshake plaintext" ); } if diag_level >= 2 { info!( prekey_client = %hex_dump(&prekey_client), prekey_server = %hex_dump(&prekey_server), "ME diag: full prekey buffers" ); } let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; if diag_level >= 1 { info!( hs_cipher = %hex_dump(&encrypted_hs), "ME diag: handshake ciphertext" ); } wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); let mut enc_buf = BytesMut::with_capacity(256); let mut dec_buf = BytesMut::with_capacity(256); let mut read_iv = ri; let mut handshake_ok = false; while Instant::now() < deadline && !handshake_ok { let remaining = deadline - Instant::now(); let mut tmp = [0u8; 256]; let n = match timeout(remaining, rd.read(&mut tmp)).await { Ok(Ok(0)) => { return Err(ProxyError::Io(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "ME closed during handshake", ))); } Ok(Ok(n)) => n, Ok(Err(e)) => return Err(ProxyError::Io(e)), Err(_) => return Err(ProxyError::TgHandshakeTimeout), }; enc_buf.extend_from_slice(&tmp[..n]); let blocks = enc_buf.len() / 16 * 16; if blocks > 0 { let mut chunk = vec![0u8; blocks]; chunk.copy_from_slice(&enc_buf[..blocks]); read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; dec_buf.extend_from_slice(&chunk); let _ = enc_buf.split_to(blocks); } while dec_buf.len() >= 4 { let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; if fl == 4 { let _ = dec_buf.split_to(4); continue; } if !(12..=(1 << 24)).contains(&fl) { return Err(ProxyError::InvalidHandshake(format!( "Bad HS response frame len: {fl}" ))); } if dec_buf.len() < fl { break; } let frame = dec_buf.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ac = crate::crypto::crc32(&frame[..pe]); if ec != ac { return Err(ProxyError::InvalidHandshake(format!( "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" ))); } let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); if hs_type == RPC_HANDSHAKE_ERROR_U32 { let err_code = if frame.len() >= 16 { i32::from_le_bytes(frame[12..16].try_into().unwrap()) } else { -1 }; return Err(ProxyError::InvalidHandshake(format!( "ME rejected handshake (error={err_code})" ))); } if hs_type != RPC_HANDSHAKE_U32 { return Err(ProxyError::InvalidHandshake(format!( "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" ))); } handshake_ok = true; break; } } if !handshake_ok { return Err(ProxyError::TgHandshakeTimeout); } info!(%addr, "RPC handshake OK"); let rpc_w = Arc::new(Mutex::new(RpcWriter { writer: wr, key: wk, iv: write_iv, seq_no: 0, })); self.writers.write().await.push((addr, rpc_w.clone())); let reg = self.registry.clone(); let w_pong = rpc_w.clone(); let w_pool = self.writers_arc(); let w_ping = rpc_w.clone(); let w_pool_ping = self.writers_arc(); tokio::spawn(async move { if let Err(e) = reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await { warn!(error = %e, "ME reader ended"); } let mut ws = w_pool.write().await; ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong)); info!(remaining = ws.len(), "Dead ME writer removed from pool"); }); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); loop { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; tokio::time::sleep(Duration::from_secs(wait)).await; let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&ping_id.to_le_bytes()); ping_id = ping_id.wrapping_add(1); if let Err(e) = w_ping.lock().await.send(&p).await { debug!(error = %e, "Active ME ping failed, removing dead writer"); let mut ws = w_pool_ping.write().await; ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping)); break; } } }); Ok(()) } } fn hex_dump(data: &[u8]) -> String { const MAX: usize = 64; let mut out = String::with_capacity(data.len() * 2 + 3); for (i, b) in data.iter().take(MAX).enumerate() { if i > 0 { out.push(' '); } out.push_str(&format!("{b:02x}")); } if data.len() > MAX { out.push_str(" …"); } out }