use std::collections::HashMap; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering}; use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; use tokio::sync::{Mutex, RwLock}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, warn}; use std::time::Duration; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::network::probe::NetworkDecision; use crate::network::IpFamily; use crate::protocol::constants::*; use super::ConnRegistry; use super::registry::{BoundConn, ConnMeta}; use super::codec::RpcWriter; use super::reader::reader_loop; use super::MeResponse; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; #[derive(Clone)] pub struct MeWriter { pub id: u64, pub addr: SocketAddr, pub writer: Arc>, pub cancel: CancellationToken, pub degraded: Arc, pub draining: Arc, } pub struct MePool { pub(super) registry: Arc, pub(super) writers: Arc>>, pub(super) rr: AtomicU64, pub(super) decision: NetworkDecision, pub(super) rng: Arc, pub(super) proxy_tag: Option>, pub(super) proxy_secret: Arc>>, pub(super) nat_ip_cfg: Option, pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, pub(super) nat_stun: Option, pub(super) detected_ipv6: Option, pub(super) nat_probe_attempts: std::sync::atomic::AtomicU8, pub(super) nat_probe_disabled: std::sync::atomic::AtomicBool, pub(super) me_one_retry: u8, pub(super) me_one_timeout: Duration, pub(super) proxy_map_v4: Arc>>>, pub(super) proxy_map_v6: Arc>>>, pub(super) default_dc: AtomicI32, pub(super) next_writer_id: AtomicU64, pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, pool_size: usize, } #[derive(Debug, Default)] pub struct NatReflectionCache { pub v4: Option<(std::time::Instant, std::net::SocketAddr)>, pub v6: Option<(std::time::Instant, std::net::SocketAddr)>, } impl MePool { pub fn new( proxy_tag: Option>, proxy_secret: Vec, nat_ip: Option, nat_probe: bool, nat_stun: Option, detected_ipv6: Option, me_one_retry: u8, me_one_timeout_ms: u64, proxy_map_v4: HashMap>, proxy_map_v6: HashMap>, default_dc: Option, decision: NetworkDecision, rng: Arc, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), writers: Arc::new(RwLock::new(Vec::new())), rr: AtomicU64::new(0), decision, rng, proxy_tag, proxy_secret: Arc::new(RwLock::new(proxy_secret)), nat_ip_cfg: nat_ip, nat_ip_detected: Arc::new(RwLock::new(None)), nat_probe, nat_stun, detected_ipv6, nat_probe_attempts: std::sync::atomic::AtomicU8::new(0), nat_probe_disabled: std::sync::atomic::AtomicBool::new(false), me_one_retry, me_one_timeout: Duration::from_millis(me_one_timeout_ms), pool_size: 2, proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), default_dc: AtomicI32::new(default_dc.unwrap_or(0)), next_writer_id: AtomicU64::new(1), ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), }) } pub fn has_proxy_tag(&self) -> bool { self.proxy_tag.is_some() } pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { let ip = self.translate_ip_for_nat(addr.ip()); SocketAddr::new(ip, addr.port()) } pub fn registry(&self) -> &Arc { &self.registry } fn writers_arc(&self) -> Arc>> { self.writers.clone() } pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { use std::collections::HashSet; let writers = self.writers.read().await; let current: HashSet = writers.iter().map(|w| w.addr).collect(); drop(writers); for family in self.family_order() { let map = self.proxy_map_for_family(family).await; for (_dc, addrs) in map.iter() { let dc_addrs: Vec = addrs .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); if !dc_addrs.iter().any(|a| current.contains(a)) { let mut shuffled = dc_addrs.clone(); shuffled.shuffle(&mut rand::rng()); for addr in shuffled { if self.connect_one(addr, rng).await.is_ok() { break; } } } } if !self.decision.effective_multipath && !current.is_empty() { break; } } } pub async fn update_proxy_maps( &self, new_v4: HashMap>, new_v6: Option>>, ) -> bool { let mut changed = false; { let mut guard = self.proxy_map_v4.write().await; if !new_v4.is_empty() && *guard != new_v4 { *guard = new_v4; changed = true; } } if let Some(v6) = new_v6 { let mut guard = self.proxy_map_v6.write().await; if !v6.is_empty() && *guard != v6 { *guard = v6; } } changed } pub async fn update_secret(&self, new_secret: Vec) -> bool { if new_secret.len() < 32 { warn!(len = new_secret.len(), "proxy-secret update ignored (too short)"); return false; } let mut guard = self.proxy_secret.write().await; if *guard != new_secret { *guard = new_secret; drop(guard); self.reconnect_all().await; return true; } false } pub async fn reconnect_all(&self) { // Graceful: do not drop all at once. New connections will use updated secret. // Existing writers remain until health monitor replaces them. // No-op here to avoid total outage. } pub(super) async fn key_selector(&self) -> u32 { let secret = self.proxy_secret.read().await; if secret.len() >= 4 { u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]]) } else { 0 } } pub(super) fn family_order(&self) -> Vec { let mut order = Vec::new(); if self.decision.prefer_ipv6() { if self.decision.ipv6_me { order.push(IpFamily::V6); } if self.decision.ipv4_me { order.push(IpFamily::V4); } } else { if self.decision.ipv4_me { order.push(IpFamily::V4); } if self.decision.ipv6_me { order.push(IpFamily::V6); } } order } async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap> { match family { IpFamily::V4 => self.proxy_map_v4.read().await.clone(), IpFamily::V6 => self.proxy_map_v6.read().await.clone(), } } pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { let family_order = self.family_order(); let ks = self.key_selector().await; info!( me_servers = self.proxy_map_v4.read().await.len(), pool_size, key_selector = format_args!("0x{ks:08x}"), secret_len = self.proxy_secret.read().await.len(), "Initializing ME pool" ); for family in family_order { let map = self.proxy_map_for_family(family).await; let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map .iter() .map(|(dc, addrs)| (*dc, addrs.clone())) .collect(); // Ensure at least one connection per DC; run DCs in parallel. let mut join = tokio::task::JoinSet::new(); let mut dc_failures = 0usize; for (dc, addrs) in dc_addrs.iter().cloned() { if addrs.is_empty() { continue; } let pool = Arc::clone(self); let rng_clone = Arc::clone(rng); join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); } while let Some(res) = join.join_next().await { if let Ok(false) = res { dc_failures += 1; } } if dc_failures > 2 { return Err(ProxyError::Proxy("Too many ME DC init failures, falling back to direct".into())); } // Additional connections up to pool_size total (round-robin across DCs) for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { break; } let addr = SocketAddr::new(*ip, *port); if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); } } if self.connection_count() >= pool_size { break; } } if !self.decision.effective_multipath && self.connection_count() > 0 { break; } } if self.writers.read().await.is_empty() { return Err(ProxyError::Proxy("No ME connections".into())); } Ok(()) } pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { let secret_len = self.proxy_secret.read().await.len(); if secret_len < 32 { return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } let (stream, _connect_ms) = self.connect_tcp(addr).await?; let hs = self.handshake_only(stream, addr, rng).await?; let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); let rpc_w = Arc::new(Mutex::new(RpcWriter { writer: hs.wr, key: hs.write_key, iv: hs.write_iv, seq_no: 0, })); let writer = MeWriter { id: writer_id, addr, writer: rpc_w.clone(), cancel: cancel.clone(), degraded: degraded.clone(), draining: draining.clone(), }; self.writers.write().await.push(writer.clone()); let reg = self.registry.clone(); let writers_arc = self.writers_arc(); let ping_tracker = self.ping_tracker.clone(); let rtt_stats = self.rtt_stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); let rpc_w_ping = rpc_w.clone(); let ping_tracker_ping = ping_tracker.clone(); tokio::spawn(async move { let cancel_reader = cancel.clone(); let res = reader_loop( hs.rd, hs.read_key, hs.read_iv, reg.clone(), BytesMut::new(), BytesMut::new(), rpc_w.clone(), ping_tracker.clone(), rtt_stats.clone(), writer_id, degraded.clone(), cancel_reader.clone(), ) .await; if let Some(pool) = pool.upgrade() { pool.remove_writer_and_close_clients(writer_id).await; } if let Err(e) = res { warn!(error = %e, "ME reader ended"); } let mut ws = writers_arc.write().await; ws.retain(|w| w.id != writer_id); info!(remaining = ws.len(), "Dead ME writer removed from pool"); }); let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); loop { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; tokio::select! { _ = cancel_ping.cancelled() => { break; } _ = tokio::time::sleep(Duration::from_secs(wait)) => {} } let sent_id = ping_id; let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await { debug!(error = %e, "Active ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() { pool.remove_writer_and_close_clients(writer_id).await; } break; } } }); Ok(()) } async fn connect_primary_for_dc( self: Arc, dc: i32, mut addrs: Vec<(IpAddr, u16)>, rng: Arc, ) -> bool { if addrs.is_empty() { return false; } addrs.shuffle(&mut rand::rng()); for (ip, port) in addrs { let addr = SocketAddr::new(ip, port); match self.connect_one(addr, rng.as_ref()).await { Ok(()) => { info!(%addr, dc = %dc, "ME connected"); return true; } Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"), } } warn!(dc = %dc, "All ME servers for DC failed at init"); false } pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) { let conns = self.remove_writer_only(writer_id).await; for bound in conns { let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; let _ = self.registry.unregister(bound.conn_id).await; } } async fn remove_writer_only(&self, writer_id: u64) -> Vec { { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); w.cancel.cancel(); } } self.registry.writer_lost(writer_id).await } pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { { let mut ws = self.writers.write().await; if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { w.draining.store(true, Ordering::Relaxed); } } let pool = Arc::downgrade(self); tokio::spawn(async move { loop { if let Some(p) = pool.upgrade() { if p.registry.is_writer_empty(writer_id).await { let _ = p.remove_writer_only(writer_id).await; break; } tokio::time::sleep(Duration::from_secs(1)).await; } else { break; } } }); } } fn hex_dump(data: &[u8]) -> String { const MAX: usize = 64; let mut out = String::with_capacity(data.len() * 2 + 3); for (i, b) in data.iter().take(MAX).enumerate() { if i > 0 { out.push(' '); } out.push_str(&format!("{b:02x}")); } if data.len() > MAX { out.push_str(" …"); } out }