diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 0a410c8..a860d01 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -16,6 +16,7 @@ use tracing::{debug, info, warn}; use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; use crate::error::{ProxyError, Result}; +use crate::network::IpFamily; use crate::protocol::constants::{ ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, @@ -101,8 +102,13 @@ impl MePool { let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; + let family = if local_addr.ip().is_ipv4() { + IpFamily::V4 + } else { + IpFamily::V6 + }; let reflected = if self.nat_probe { - self.maybe_reflect_public_addr().await + self.maybe_reflect_public_addr(family).await } else { None }; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index d2bb51a..348e1d7 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -7,107 +7,84 @@ use tracing::{debug, info, warn}; use rand::seq::SliceRandom; use crate::crypto::SecureRandom; +use crate::network::IpFamily; use super::MePool; pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_connections: usize) { - let mut backoff: HashMap = HashMap::new(); - let mut last_attempt: HashMap = HashMap::new(); + let mut backoff: HashMap<(i32, IpFamily), u64> = HashMap::new(); + let mut last_attempt: HashMap<(i32, IpFamily), Instant> = HashMap::new(); loop { tokio::time::sleep(Duration::from_secs(30)).await; - // Per-DC coverage check - let map = pool.proxy_map_v4.read().await.clone(); - let writer_addrs: std::collections::HashSet = pool - .writers - .read() - .await - .iter() - .map(|w| w.addr) - .collect(); + check_family(IpFamily::V4, &pool, &rng, &mut backoff, &mut last_attempt).await; + check_family(IpFamily::V6, &pool, &rng, &mut backoff, &mut last_attempt).await; + } +} - for (dc, addrs) in map.iter() { - let dc_addrs: Vec = addrs - .iter() - .map(|(ip, port)| SocketAddr::new(*ip, *port)) - .collect(); - let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); - if !has_coverage { - let delay = *backoff.get(dc).unwrap_or(&30); - let now = Instant::now(); - if let Some(last) = last_attempt.get(dc) { - if now.duration_since(*last).as_secs() < delay { - continue; - } - } - warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting..."); - let mut shuffled = dc_addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - let mut reconnected = false; - for addr in shuffled { - match pool.connect_one(addr, &rng).await { - Ok(()) => { - info!(%addr, dc = %dc, "ME reconnected for DC coverage"); - backoff.insert(*dc, 30); - last_attempt.insert(*dc, now); - reconnected = true; - break; - } - Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"), - } - } - if !reconnected { - let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); - backoff.insert(*dc, next); - last_attempt.insert(*dc, now); - } +async fn check_family( + family: IpFamily, + pool: &Arc, + rng: &Arc, + backoff: &mut HashMap<(i32, IpFamily), u64>, + last_attempt: &mut HashMap<(i32, IpFamily), Instant>, +) { + let enabled = match family { + IpFamily::V4 => pool.decision.ipv4_me, + IpFamily::V6 => pool.decision.ipv6_me, + }; + if !enabled { + return; + } + + let map = match family { + IpFamily::V4 => pool.proxy_map_v4.read().await.clone(), + IpFamily::V6 => pool.proxy_map_v6.read().await.clone(), + }; + let writer_addrs: HashSet = pool + .writers + .read() + .await + .iter() + .map(|w| w.addr) + .collect(); + + for (dc, addrs) in map.iter() { + let dc_addrs: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); + if has_coverage { + continue; + } + let key = (*dc, family); + let delay = *backoff.get(&key).unwrap_or(&30); + let now = Instant::now(); + if let Some(last) = last_attempt.get(&key) { + if now.duration_since(*last).as_secs() < delay { + continue; } } - - // IPv6 coverage check (if available) - let map_v6 = pool.proxy_map_v6.read().await.clone(); - let writer_addrs_v6: std::collections::HashSet = pool - .writers - .read() - .await - .iter() - .map(|w| w.addr) - .collect(); - for (dc, addrs) in map_v6.iter() { - let dc_addrs: Vec = addrs - .iter() - .map(|(ip, port)| SocketAddr::new(*ip, *port)) - .collect(); - let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a)); - if !has_coverage { - let delay = *backoff.get(dc).unwrap_or(&30); - let now = Instant::now(); - if let Some(last) = last_attempt.get(dc) { - if now.duration_since(*last).as_secs() < delay { - continue; - } - } - warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting..."); - let mut shuffled = dc_addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - let mut reconnected = false; - for addr in shuffled { - match pool.connect_one(addr, &rng).await { - Ok(()) => { - info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage"); - backoff.insert(*dc, 30); - last_attempt.insert(*dc, now); - reconnected = true; - break; - } - Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"), - } - } - if !reconnected { - let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); - backoff.insert(*dc, next); - last_attempt.insert(*dc, now); + warn!(dc = %dc, delay, ?family, "DC has no ME coverage, reconnecting..."); + let mut shuffled = dc_addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + let mut reconnected = false; + for addr in shuffled { + match pool.connect_one(addr, rng.as_ref()).await { + Ok(()) => { + info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage"); + backoff.insert(key, 30); + last_attempt.insert(key, now); + reconnected = true; + break; } + Err(e) => debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed"), } } + if !reconnected { + let next = (*backoff.get(&key).unwrap_or(&30)).saturating_mul(2).min(300); + backoff.insert(key, next); + last_attempt.insert(key, now); + } } } diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 26d07dd..1027221 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -19,7 +19,7 @@ use bytes::Bytes; pub use health::me_health_monitor; pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use pool::MePool; -pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult}; +pub use pool_nat::{stun_probe, detect_public_ip}; pub use registry::ConnRegistry; pub use secret::fetch_proxy_secret; pub use config_updater::{fetch_proxy_config, me_config_updater}; diff --git a/src/transport/middle_proxy/ping.rs b/src/transport/middle_proxy/ping.rs index 22b1f6d..36ef4e7 100644 --- a/src/transport/middle_proxy/ping.rs +++ b/src/transport/middle_proxy/ping.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; @@ -92,8 +93,16 @@ mod tests { pub async fn run_me_ping(pool: &Arc, rng: &SecureRandom) -> Vec { let mut reports = Vec::new(); - let v4_map = pool.proxy_map_v4.read().await.clone(); - let v6_map = pool.proxy_map_v6.read().await.clone(); + let v4_map = if pool.decision.ipv4_me { + pool.proxy_map_v4.read().await.clone() + } else { + HashMap::new() + }; + let v6_map = if pool.decision.ipv6_me { + pool.proxy_map_v6.read().await.clone() + } else { + HashMap::new() + }; let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new(); for (dc, addrs) in v4_map { diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 7305f5e..9771b6b 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -12,6 +12,8 @@ use std::time::Duration; use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; +use crate::network::probe::NetworkDecision; +use crate::network::IpFamily; use crate::protocol::constants::*; use super::ConnRegistry; @@ -36,6 +38,8 @@ pub struct MePool { pub(super) registry: Arc, pub(super) writers: Arc>>, pub(super) rr: AtomicU64, + pub(super) decision: NetworkDecision, + pub(super) rng: Arc, pub(super) proxy_tag: Option>, pub(super) proxy_secret: Arc>>, pub(super) nat_ip_cfg: Option, @@ -48,10 +52,16 @@ pub struct MePool { pub(super) next_writer_id: AtomicU64, pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, - pub(super) nat_reflection_cache: Arc>>, + pub(super) nat_reflection_cache: Arc>, pool_size: usize, } +#[derive(Debug, Default)] +pub struct NatReflectionCache { + pub v4: Option<(std::time::Instant, std::net::SocketAddr)>, + pub v6: Option<(std::time::Instant, std::net::SocketAddr)>, +} + impl MePool { pub fn new( proxy_tag: Option>, @@ -62,11 +72,15 @@ impl MePool { proxy_map_v4: HashMap>, proxy_map_v6: HashMap>, default_dc: Option, + decision: NetworkDecision, + rng: Arc, ) -> Arc { Arc::new(Self { registry: Arc::new(ConnRegistry::new()), writers: Arc::new(RwLock::new(Vec::new())), rr: AtomicU64::new(0), + decision, + rng, proxy_tag, proxy_secret: Arc::new(RwLock::new(proxy_secret)), nat_ip_cfg: nat_ip, @@ -80,7 +94,7 @@ impl MePool { next_writer_id: AtomicU64::new(1), ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), - nat_reflection_cache: Arc::new(Mutex::new(None)), + nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), }) } @@ -103,29 +117,30 @@ impl MePool { pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { use std::collections::HashSet; - let map = self.proxy_map_v4.read().await.clone(); - let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map - .iter() - .map(|(dc, addrs)| (*dc, addrs.clone())) - .collect(); let writers = self.writers.read().await; let current: HashSet = writers.iter().map(|w| w.addr).collect(); drop(writers); - for (_dc, addrs) in map.iter() { - let dc_addrs: Vec = addrs - .iter() - .map(|(ip, port)| SocketAddr::new(*ip, *port)) - .collect(); - if !dc_addrs.iter().any(|a| current.contains(a)) { - let mut shuffled = dc_addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - for addr in shuffled { - if self.connect_one(addr, rng).await.is_ok() { - break; + for family in self.family_order() { + let map = self.proxy_map_for_family(family).await; + for (_dc, addrs) in map.iter() { + let dc_addrs: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + if !dc_addrs.iter().any(|a| current.contains(a)) { + let mut shuffled = dc_addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + for addr in shuffled { + if self.connect_one(addr, rng).await.is_ok() { + break; + } } } } + if !self.decision.effective_multipath && !current.is_empty() { + break; + } } } @@ -181,47 +196,82 @@ impl MePool { } } + pub(super) fn family_order(&self) -> Vec { + let mut order = Vec::new(); + if self.decision.prefer_ipv6() { + if self.decision.ipv6_me { + order.push(IpFamily::V6); + } + if self.decision.ipv4_me { + order.push(IpFamily::V4); + } + } else { + if self.decision.ipv4_me { + order.push(IpFamily::V4); + } + if self.decision.ipv6_me { + order.push(IpFamily::V6); + } + } + order + } + + async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap> { + match family { + IpFamily::V4 => self.proxy_map_v4.read().await.clone(), + IpFamily::V6 => self.proxy_map_v6.read().await.clone(), + } + } + pub async fn init(self: &Arc, pool_size: usize, rng: &Arc) -> Result<()> { - let map = self.proxy_map_v4.read().await.clone(); - let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map - .iter() - .map(|(dc, addrs)| (*dc, addrs.clone())) - .collect(); + let family_order = self.family_order(); let ks = self.key_selector().await; info!( - me_servers = map.len(), + me_servers = self.proxy_map_v4.read().await.len(), pool_size, key_selector = format_args!("0x{ks:08x}"), secret_len = self.proxy_secret.read().await.len(), "Initializing ME pool" ); - // Ensure at least one connection per DC; run DCs in parallel. - let mut join = tokio::task::JoinSet::new(); - for (dc, addrs) in dc_addrs.iter().cloned() { - if addrs.is_empty() { - continue; - } - let pool = Arc::clone(self); - let rng_clone = Arc::clone(rng); - join.spawn(async move { - pool.connect_primary_for_dc(dc, addrs, rng_clone).await; - }); - } - while let Some(_res) = join.join_next().await {} + for family in family_order { + let map = self.proxy_map_for_family(family).await; + let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map + .iter() + .map(|(dc, addrs)| (*dc, addrs.clone())) + .collect(); - // Additional connections up to pool_size total (round-robin across DCs) - for (dc, addrs) in dc_addrs.iter() { - for (ip, port) in addrs { + // Ensure at least one connection per DC; run DCs in parallel. + let mut join = tokio::task::JoinSet::new(); + for (dc, addrs) in dc_addrs.iter().cloned() { + if addrs.is_empty() { + continue; + } + let pool = Arc::clone(self); + let rng_clone = Arc::clone(rng); + join.spawn(async move { + pool.connect_primary_for_dc(dc, addrs, rng_clone).await; + }); + } + while let Some(_res) = join.join_next().await {} + + // Additional connections up to pool_size total (round-robin across DCs) + for (dc, addrs) in dc_addrs.iter() { + for (ip, port) in addrs { + if self.connection_count() >= pool_size { + break; + } + let addr = SocketAddr::new(*ip, *port); + if let Err(e) = self.connect_one(addr, rng.as_ref()).await { + debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); + } + } if self.connection_count() >= pool_size { break; } - let addr = SocketAddr::new(*ip, *port); - if let Err(e) = self.connect_one(addr, rng.as_ref()).await { - debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); - } } - if self.connection_count() >= pool_size { + + if !self.decision.effective_multipath && self.connection_count() > 0 { break; } } @@ -309,14 +359,15 @@ impl MePool { } _ = tokio::time::sleep(Duration::from_secs(wait)) => {} } + let sent_id = ping_id; let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); - p.extend_from_slice(&ping_id.to_le_bytes()); - ping_id = ping_id.wrapping_add(1); + p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; - tracker.insert(ping_id, (std::time::Instant::now(), writer_id)); + tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } + ping_id = ping_id.wrapping_add(1); if let Err(e) = rpc_w_ping.lock().await.send(&p).await { debug!(error = %e, "Active ME ping failed, removing dead writer"); cancel_ping.cancel(); diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 3a69118..db34e09 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -4,19 +4,14 @@ use std::time::Duration; use tracing::{info, warn}; use crate::error::{ProxyError, Result}; +use crate::network::probe::is_bogon; +use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult}; use super::MePool; use std::time::Instant; - -#[derive(Debug, Clone, Copy)] -pub struct StunProbeResult { - pub local_addr: std::net::SocketAddr, - pub reflected_addr: std::net::SocketAddr, -} - -pub async fn stun_probe(stun_addr: Option) -> Result> { +pub async fn stun_probe(stun_addr: Option) -> Result { let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); - fetch_stun_binding(&stun_addr).await + stun_probe_dual(&stun_addr).await } pub async fn detect_public_ip() -> Option { @@ -35,7 +30,7 @@ impl MePool { match (ip, nat_ip) { (IpAddr::V4(src), IpAddr::V4(dst)) - if is_privateish(IpAddr::V4(src)) + if is_bogon(IpAddr::V4(src)) || src.is_loopback() || src.is_unspecified() => { @@ -55,7 +50,7 @@ impl MePool { ) -> std::net::SocketAddr { let ip = if let Some(r) = reflected { // Use reflected IP (not port) only when local address is non-public. - if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() { + if is_bogon(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() { r.ip() } else { self.translate_ip_for_nat(addr.ip()) @@ -73,7 +68,7 @@ impl MePool { return self.nat_ip_cfg; } - if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { + if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { return None; } @@ -98,12 +93,19 @@ impl MePool { } } - pub(super) async fn maybe_reflect_public_addr(&self) -> Option { + pub(super) async fn maybe_reflect_public_addr( + &self, + family: IpFamily, + ) -> Option { const STUN_CACHE_TTL: Duration = Duration::from_secs(600); if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { - if let Some((ts, addr)) = *cache { + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + if let Some((ts, addr)) = slot { if ts.elapsed() < STUN_CACHE_TTL { - return Some(addr); + return Some(*addr); } } } @@ -112,12 +114,20 @@ impl MePool { .nat_stun .clone() .unwrap_or_else(|| "stun.l.google.com:19302".to_string()); - match fetch_stun_binding(&stun_addr).await { - Ok(sa) => { - if let Some(result) = sa { - info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address"); + match stun_probe_dual(&stun_addr).await { + Ok(res) => { + let picked: Option = match family { + IpFamily::V4 => res.v4, + IpFamily::V6 => res.v6, + }; + if let Some(result) = picked { + info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address"); if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { - *cache = Some((Instant::now(), result.reflected_addr)); + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + *slot = Some((Instant::now(), result.reflected_addr)); } Some(result.reflected_addr) } else { @@ -158,98 +168,3 @@ async fn fetch_public_ipv4_once(url: &str) -> Result> { let ip = text.trim().parse().ok(); Ok(ip) } - -async fn fetch_stun_binding(stun_addr: &str) -> Result> { - use rand::RngCore; - use tokio::net::UdpSocket; - - let socket = UdpSocket::bind("0.0.0.0:0") - .await - .map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?; - socket - .connect(stun_addr) - .await - .map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?; - - // Build minimal Binding Request. - let mut req = vec![0u8; 20]; - req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request - req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length - req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie - rand::rng().fill_bytes(&mut req[8..20]); - - socket - .send(&req) - .await - .map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?; - - let mut buf = [0u8; 128]; - let n = socket - .recv(&mut buf) - .await - .map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?; - if n < 20 { - return Ok(None); - } - - // Parse attributes. - let mut idx = 20; - while idx + 4 <= n { - let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); - let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize; - idx += 4; - if idx + alen > n { - break; - } - match atype { - 0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => { - if alen < 8 { - break; - } - let family = buf[idx + 1]; - if family != 0x01 { - // only IPv4 supported here - break; - } - let port_bytes = [buf[idx + 2], buf[idx + 3]]; - let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]]; - - let (port, ip) = if atype == 0x0020 { - let magic = 0x2112A442u32.to_be_bytes(); - let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16); - let ip = [ - ip_bytes[0] ^ magic[0], - ip_bytes[1] ^ magic[1], - ip_bytes[2] ^ magic[2], - ip_bytes[3] ^ magic[3], - ]; - (port, ip) - } else { - (u16::from_be_bytes(port_bytes), ip_bytes) - }; - let reflected = std::net::SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])), - port, - ); - let local_addr = socket.local_addr().map_err(|e| { - ProxyError::Proxy(format!("STUN local_addr failed: {e}")) - })?; - return Ok(Some(StunProbeResult { - local_addr, - reflected_addr: reflected, - })); - } - _ => {} - } - idx += (alen + 3) & !3; // 4-byte alignment - } - - Ok(None) -} - -fn is_privateish(ip: IpAddr) -> bool { - match ip { - IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(), - IpAddr::V6(v6) => v6.is_unique_local(), - } -} diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 3ca02d5..b53ddef 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -152,6 +152,9 @@ pub(crate) async fn reader_loop( entry.1 = entry.1 * 0.8 + rtt * 0.2; if rtt < entry.0 { entry.0 = rtt; + } else { + // allow slow baseline drift upward to avoid stale minimum + entry.0 = entry.0 * 0.99 + rtt * 0.01; } let degraded_now = entry.1 > entry.0 * 2.0; degraded.store(degraded_now, Ordering::Relaxed); diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs index 5313bdb..5457f70 100644 --- a/src/transport/middle_proxy/rotation.rs +++ b/src/transport/middle_proxy/rotation.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::sync::atomic::Ordering; use std::time::Duration; use tracing::{info, warn}; @@ -15,7 +16,12 @@ pub async fn me_rotation_task(pool: Arc, rng: Arc, interva let candidate = { let ws = pool.writers.read().await; - ws.get(0).cloned() + if ws.is_empty() { + None + } else { + let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len(); + ws.get(idx).cloned() + } }; let Some(w) = candidate else { @@ -34,4 +40,3 @@ pub async fn me_rotation_task(pool: Arc, rng: Arc, interva } } } - diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 174127d..5eaacf0 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -3,15 +3,14 @@ use std::sync::Arc; use std::sync::atomic::Ordering; use std::time::Duration; -use tokio::sync::Mutex; use tracing::{debug, warn}; use crate::error::{ProxyError, Result}; +use crate::network::IpFamily; use crate::protocol::constants::RPC_CLOSE_EXT_U32; use super::MePool; use super::wire::build_proxy_req_payload; -use crate::crypto::SecureRandom; use rand::seq::SliceRandom; use super::registry::ConnMeta; @@ -84,7 +83,7 @@ impl MePool { drop(map); for (ip, port) in shuffled { let addr = SocketAddr::new(ip, port); - if self.connect_one(addr, &SecureRandom::new()).await.is_ok() { + if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { break; } } @@ -173,32 +172,44 @@ impl MePool { writers: &[super::pool::MeWriter], target_dc: i16, ) -> Vec { - let mut preferred = Vec::::new(); let key = target_dc as i32; - let map = self.proxy_map_v4.read().await; + let mut preferred = Vec::::new(); - if let Some(v) = map.get(&key) { - preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); - } - if preferred.is_empty() { - let abs = key.abs(); - if let Some(v) = map.get(&abs) { + for family in self.family_order() { + let map_guard = match family { + IpFamily::V4 => self.proxy_map_v4.read().await, + IpFamily::V6 => self.proxy_map_v6.read().await, + }; + + if let Some(v) = map_guard.get(&key) { preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); } - } - if preferred.is_empty() { - let abs = key.abs(); - if let Some(v) = map.get(&-abs) { - preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); - } - } - if preferred.is_empty() { - let def = self.default_dc.load(Ordering::Relaxed); - if def != 0 { - if let Some(v) = map.get(&def) { + if preferred.is_empty() { + let abs = key.abs(); + if let Some(v) = map_guard.get(&abs) { preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); } } + if preferred.is_empty() { + let abs = key.abs(); + if let Some(v) = map_guard.get(&-abs) { + preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); + } + } + if preferred.is_empty() { + let def = self.default_dc.load(Ordering::Relaxed); + if def != 0 { + if let Some(v) = map_guard.get(&def) { + preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); + } + } + } + + drop(map_guard); + + if !preferred.is_empty() && !self.decision.effective_multipath { + break; + } } if preferred.is_empty() {