diff --git a/Cargo.toml b/Cargo.toml index c0f43db..4bb5172 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ reqwest = { version = "0.12", features = ["rustls-tls"], default-features = fals hyper = { version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } http-body-util = "0.1" +httpdate = "1.0" [dev-dependencies] tokio-test = "0.4" diff --git a/src/main.rs b/src/main.rs index b820785..2dd9a56 100644 --- a/src/main.rs +++ b/src/main.rs @@ -256,10 +256,22 @@ async fn main() -> std::result::Result<(), Box> { if probe.local_addr.ip() != probe.reflected_addr.ip() && !config.general.stun_iface_mismatch_ignore { - warn!( - "STUN/IP-on-Interface mismatch -> fallback to direct-DC" - ); - use_middle_proxy = false; + match crate::transport::middle_proxy::detect_public_ip().await { + Some(ip) => { + info!( + local_ip = %probe.local_addr.ip(), + reflected_ip = %probe.reflected_addr.ip(), + public_ip = %ip, + "STUN mismatch but public IP auto-detected, continuing with middle proxy" + ); + } + None => { + warn!( + "STUN/IP-on-Interface mismatch and public IP auto-detect failed -> fallback to direct-DC" + ); + use_middle_proxy = false; + } + } } } Ok(None) => warn!("STUN probe returned no address; continuing"), @@ -355,6 +367,18 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai .await; }); + // Periodic ME connection rotation + let pool_clone_rot = pool.clone(); + let rng_clone_rot = rng.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_rotation_task( + pool_clone_rot, + rng_clone_rot, + std::time::Duration::from_secs(1800), + ) + .await; + }); + // Periodic updater: getProxyConfig + proxy-secret let pool_clone2 = pool.clone(); let rng_clone2 = rng.clone(); diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 51daee9..326bf90 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -174,6 +174,7 @@ impl RpcWriter { if buf.len() >= 16 { self.iv.copy_from_slice(&buf[buf.len() - 16..]); } - self.writer.write_all(&buf).await.map_err(ProxyError::Io) + self.writer.write_all(&buf).await.map_err(ProxyError::Io)?; + self.writer.flush().await.map_err(ProxyError::Io) } } diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index aed5a54..8ac6986 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use std::time::Duration; use regex::Regex; +use httpdate; use tracing::{debug, info, warn}; use crate::error::Result; @@ -11,6 +12,7 @@ use crate::error::Result; use super::MePool; use super::secret::download_proxy_secret; use crate::crypto::SecureRandom; +use std::time::SystemTime; #[derive(Debug, Clone, Default)] pub struct ProxyConfigData { @@ -19,9 +21,29 @@ pub struct ProxyConfigData { } pub async fn fetch_proxy_config(url: &str) -> Result { - let text = reqwest::get(url) + let resp = reqwest::get(url) .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))? + ; + + if let Some(date) = resp.headers().get(reqwest::header::DATE) { + if let Ok(date_str) = date.to_str() { + if let Ok(server_time) = httpdate::parse_http_date(date_str) { + if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| { + server_time.duration_since(SystemTime::now()).map_err(|_| e) + }) { + let skew_secs = skew.as_secs(); + if skew_secs > 60 { + warn!(skew_secs, "Time skew >60s detected from fetch_proxy_config Date header"); + } else if skew_secs > 30 { + warn!(skew_secs, "Time skew >30s detected from fetch_proxy_config Date header"); + } + } + } + } + } + + let text = resp .text() .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?; diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 31e5030..d2bb51a 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -1,6 +1,7 @@ +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tracing::{debug, info, warn}; use rand::seq::SliceRandom; @@ -10,6 +11,8 @@ use crate::crypto::SecureRandom; use super::MePool; pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_connections: usize) { + let mut backoff: HashMap = HashMap::new(); + let mut last_attempt: HashMap = HashMap::new(); loop { tokio::time::sleep(Duration::from_secs(30)).await; // Per-DC coverage check @@ -29,18 +32,81 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c .collect(); let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); if !has_coverage { - warn!(dc = %dc, "DC has no ME coverage, reconnecting..."); + let delay = *backoff.get(dc).unwrap_or(&30); + let now = Instant::now(); + if let Some(last) = last_attempt.get(dc) { + if now.duration_since(*last).as_secs() < delay { + continue; + } + } + warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting..."); let mut shuffled = dc_addrs.clone(); shuffled.shuffle(&mut rand::rng()); + let mut reconnected = false; for addr in shuffled { match pool.connect_one(addr, &rng).await { Ok(()) => { info!(%addr, dc = %dc, "ME reconnected for DC coverage"); + backoff.insert(*dc, 30); + last_attempt.insert(*dc, now); + reconnected = true; break; } Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"), } } + if !reconnected { + let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); + backoff.insert(*dc, next); + last_attempt.insert(*dc, now); + } + } + } + + // IPv6 coverage check (if available) + let map_v6 = pool.proxy_map_v6.read().await.clone(); + let writer_addrs_v6: std::collections::HashSet = pool + .writers + .read() + .await + .iter() + .map(|w| w.addr) + .collect(); + for (dc, addrs) in map_v6.iter() { + let dc_addrs: Vec = addrs + .iter() + .map(|(ip, port)| SocketAddr::new(*ip, *port)) + .collect(); + let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a)); + if !has_coverage { + let delay = *backoff.get(dc).unwrap_or(&30); + let now = Instant::now(); + if let Some(last) = last_attempt.get(dc) { + if now.duration_since(*last).as_secs() < delay { + continue; + } + } + warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting..."); + let mut shuffled = dc_addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + let mut reconnected = false; + for addr in shuffled { + match pool.connect_one(addr, &rng).await { + Ok(()) => { + info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage"); + backoff.insert(*dc, 30); + last_attempt.insert(*dc, now); + reconnected = true; + break; + } + Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"), + } + } + if !reconnected { + let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); + backoff.insert(*dc, next); + last_attempt.insert(*dc, now); + } } } } diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 443c189..26d07dd 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -10,6 +10,7 @@ mod reader; mod registry; mod send; mod secret; +mod rotation; mod config_updater; mod wire; @@ -18,10 +19,11 @@ use bytes::Bytes; pub use health::me_health_monitor; pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use pool::MePool; -pub use pool_nat::{stun_probe, StunProbeResult}; +pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult}; pub use registry::ConnRegistry; pub use secret::fetch_proxy_secret; pub use config_updater::{fetch_proxy_config, me_config_updater}; +pub use rotation::me_rotation_task; pub use wire::proto_flags_for_tag; #[derive(Debug)] diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 873a65f..7305f5e 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -48,6 +48,7 @@ pub struct MePool { pub(super) next_writer_id: AtomicU64, pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, + pub(super) nat_reflection_cache: Arc>>, pool_size: usize, } @@ -79,6 +80,7 @@ impl MePool { next_writer_id: AtomicU64::new(1), ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), + nat_reflection_cache: Arc::new(Mutex::new(None)), }) } diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index 35ee0ea..3a69118 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -1,10 +1,12 @@ use std::net::{IpAddr, Ipv4Addr}; +use std::time::Duration; use tracing::{info, warn}; use crate::error::{ProxyError, Result}; use super::MePool; +use std::time::Instant; #[derive(Debug, Clone, Copy)] pub struct StunProbeResult { @@ -17,6 +19,10 @@ pub async fn stun_probe(stun_addr: Option) -> Result Option { + fetch_public_ipv4_with_retry().await.ok().flatten().map(IpAddr::V4) +} + impl MePool { pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { let nat_ip = self @@ -93,6 +99,15 @@ impl MePool { } pub(super) async fn maybe_reflect_public_addr(&self) -> Option { + const STUN_CACHE_TTL: Duration = Duration::from_secs(600); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + if let Some((ts, addr)) = *cache { + if ts.elapsed() < STUN_CACHE_TTL { + return Some(addr); + } + } + } + let stun_addr = self .nat_stun .clone() @@ -101,6 +116,9 @@ impl MePool { Ok(sa) => { if let Some(result) = sa { info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address"); + if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { + *cache = Some((Instant::now(), result.reflected_addr)); + } Some(result.reflected_addr) } else { None diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 58e7bfb..3ca02d5 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -32,6 +32,7 @@ pub(crate) async fn reader_loop( cancel: CancellationToken, ) -> Result<()> { let mut raw = enc_leftover; + let mut expected_seq: i32 = 0; loop { let mut tmp = [0u8; 16_384]; @@ -82,6 +83,14 @@ pub(crate) async fn reader_loop( continue; } + let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); + if seq_no != expected_seq { + warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); + expected_seq = seq_no.wrapping_add(1); + } else { + expected_seq = expected_seq.wrapping_add(1); + } + let payload = &frame[8..pe]; if payload.len() < 4 { continue; diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 75e9fba..9905d1d 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -29,7 +29,7 @@ pub struct ConnWriter { } pub struct ConnRegistry { - map: RwLock>>, + map: RwLock>>, writers: RwLock>>>, writer_for_conn: RwLock>, conns_for_writer: RwLock>>, @@ -50,9 +50,9 @@ impl ConnRegistry { } } - pub async fn register(&self) -> (u64, mpsc::UnboundedReceiver) { + pub async fn register(&self) -> (u64, mpsc::Receiver) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(1024); self.map.write().await.insert(id, tx); (id, rx) } @@ -70,7 +70,7 @@ impl ConnRegistry { pub async fn route(&self, id: u64, resp: MeResponse) -> bool { let m = self.map.read().await; if let Some(tx) = m.get(&id) { - tx.send(resp).is_ok() + tx.try_send(resp).is_ok() } else { false } diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs new file mode 100644 index 0000000..5313bdb --- /dev/null +++ b/src/transport/middle_proxy/rotation.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; +use std::time::Duration; + +use tracing::{info, warn}; + +use crate::crypto::SecureRandom; + +use super::MePool; + +/// Periodically refresh ME connections to avoid long-lived degradation. +pub async fn me_rotation_task(pool: Arc, rng: Arc, interval: Duration) { + let interval = interval.max(Duration::from_secs(600)); + loop { + tokio::time::sleep(interval).await; + + let candidate = { + let ws = pool.writers.read().await; + ws.get(0).cloned() + }; + + let Some(w) = candidate else { + continue; + }; + + info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection"); + match pool.connect_one(w.addr, rng.as_ref()).await { + Ok(()) => { + // Remove old writer after new one is up. + pool.remove_writer_and_reroute(w.id).await; + } + Err(e) => { + warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed"); + } + } + } +} + diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 9dba939..a9e224d 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,6 +1,8 @@ use std::time::Duration; use tracing::{debug, info, warn}; +use std::time::SystemTime; +use httpdate; use crate::error::{ProxyError, Result}; @@ -63,6 +65,23 @@ pub async fn download_proxy_secret() -> Result> { ))); } + if let Some(date) = resp.headers().get(reqwest::header::DATE) { + if let Ok(date_str) = date.to_str() { + if let Ok(server_time) = httpdate::parse_http_date(date_str) { + if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| { + server_time.duration_since(SystemTime::now()).map_err(|_| e) + }) { + let skew_secs = skew.as_secs(); + if skew_secs > 60 { + warn!(skew_secs, "Time skew >60s detected from proxy-secret Date header"); + } else if skew_secs > 30 { + warn!(skew_secs, "Time skew >30s detected from proxy-secret Date header"); + } + } + } + } + } + let data = resp .bytes() .await diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index ad1c01f..174127d 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -1,6 +1,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::Ordering; +use std::time::Duration; use tokio::sync::Mutex; use tracing::{debug, warn}; @@ -38,6 +39,7 @@ impl MePool { our_addr, proto_flags, }; + let mut emergency_attempts = 0; loop { if let Some(current) = self.registry.get_writer(conn_id).await { @@ -71,6 +73,10 @@ impl MePool { let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; if candidate_indices.is_empty() { // Emergency connect-on-demand + if emergency_attempts >= 3 { + return Err(ProxyError::Proxy("No ME writers available for target DC".into())); + } + emergency_attempts += 1; let map = self.proxy_map_v4.read().await; if let Some(addrs) = map.get(&(target_dc as i32)) { let mut shuffled = addrs.clone(); @@ -82,6 +88,7 @@ impl MePool { break; } } + tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await; let ws2 = self.writers.read().await; writers_snapshot = ws2.clone(); drop(ws2);