Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-18 06:01:52 +03:00
parent 8046381939
commit eb9ac7fae4
9 changed files with 259 additions and 282 deletions

View File

@@ -16,6 +16,7 @@ use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::{ use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32,
@@ -101,8 +102,13 @@ impl MePool {
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
let family = if local_addr.ip().is_ipv4() {
IpFamily::V4
} else {
IpFamily::V6
};
let reflected = if self.nat_probe { let reflected = if self.nat_probe {
self.maybe_reflect_public_addr().await self.maybe_reflect_public_addr(family).await
} else { } else {
None None
}; };

View File

@@ -1,4 +1,4 @@
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@@ -7,107 +7,84 @@ use tracing::{debug, info, warn};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::network::IpFamily;
use super::MePool; use super::MePool;
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) { pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
let mut backoff: HashMap<i32, u64> = HashMap::new(); let mut backoff: HashMap<(i32, IpFamily), u64> = HashMap::new();
let mut last_attempt: HashMap<i32, Instant> = HashMap::new(); let mut last_attempt: HashMap<(i32, IpFamily), Instant> = HashMap::new();
loop { loop {
tokio::time::sleep(Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
// Per-DC coverage check check_family(IpFamily::V4, &pool, &rng, &mut backoff, &mut last_attempt).await;
let map = pool.proxy_map_v4.read().await.clone(); check_family(IpFamily::V6, &pool, &rng, &mut backoff, &mut last_attempt).await;
let writer_addrs: std::collections::HashSet<SocketAddr> = pool }
.writers }
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
for (dc, addrs) in map.iter() { async fn check_family(
let dc_addrs: Vec<SocketAddr> = addrs family: IpFamily,
.iter() pool: &Arc<MePool>,
.map(|(ip, port)| SocketAddr::new(*ip, *port)) rng: &Arc<SecureRandom>,
.collect(); backoff: &mut HashMap<(i32, IpFamily), u64>,
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); last_attempt: &mut HashMap<(i32, IpFamily), Instant>,
if !has_coverage { ) {
let delay = *backoff.get(dc).unwrap_or(&30); let enabled = match family {
let now = Instant::now(); IpFamily::V4 => pool.decision.ipv4_me,
if let Some(last) = last_attempt.get(dc) { IpFamily::V6 => pool.decision.ipv6_me,
if now.duration_since(*last).as_secs() < delay { };
continue; if !enabled {
} return;
} }
warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone(); let map = match family {
shuffled.shuffle(&mut rand::rng()); IpFamily::V4 => pool.proxy_map_v4.read().await.clone(),
let mut reconnected = false; IpFamily::V6 => pool.proxy_map_v6.read().await.clone(),
for addr in shuffled { };
match pool.connect_one(addr, &rng).await { let writer_addrs: HashSet<SocketAddr> = pool
Ok(()) => { .writers
info!(%addr, dc = %dc, "ME reconnected for DC coverage"); .read()
backoff.insert(*dc, 30); .await
last_attempt.insert(*dc, now); .iter()
reconnected = true; .map(|w| w.addr)
break; .collect();
}
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"), for (dc, addrs) in map.iter() {
} let dc_addrs: Vec<SocketAddr> = addrs
} .iter()
if !reconnected { .map(|(ip, port)| SocketAddr::new(*ip, *port))
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300); .collect();
backoff.insert(*dc, next); let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
last_attempt.insert(*dc, now); if has_coverage {
} continue;
}
let key = (*dc, family);
let delay = *backoff.get(&key).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(&key) {
if now.duration_since(*last).as_secs() < delay {
continue;
} }
} }
warn!(dc = %dc, delay, ?family, "DC has no ME coverage, reconnecting...");
// IPv6 coverage check (if available) let mut shuffled = dc_addrs.clone();
let map_v6 = pool.proxy_map_v6.read().await.clone(); shuffled.shuffle(&mut rand::rng());
let writer_addrs_v6: std::collections::HashSet<SocketAddr> = pool let mut reconnected = false;
.writers for addr in shuffled {
.read() match pool.connect_one(addr, rng.as_ref()).await {
.await Ok(()) => {
.iter() info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage");
.map(|w| w.addr) backoff.insert(key, 30);
.collect(); last_attempt.insert(key, now);
for (dc, addrs) in map_v6.iter() { reconnected = true;
let dc_addrs: Vec<SocketAddr> = addrs break;
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a));
if !has_coverage {
let delay = *backoff.get(dc).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(dc) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, &rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage");
backoff.insert(*dc, 30);
last_attempt.insert(*dc, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"),
}
}
if !reconnected {
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(*dc, next);
last_attempt.insert(*dc, now);
} }
Err(e) => debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed"),
} }
} }
if !reconnected {
let next = (*backoff.get(&key).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(key, next);
last_attempt.insert(key, now);
}
} }
} }

View File

@@ -19,7 +19,7 @@ use bytes::Bytes;
pub use health::me_health_monitor; pub use health::me_health_monitor;
pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily};
pub use pool::MePool; pub use pool::MePool;
pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult}; pub use pool_nat::{stun_probe, detect_public_ip};
pub use registry::ConnRegistry; pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret; pub use secret::fetch_proxy_secret;
pub use config_updater::{fetch_proxy_config, me_config_updater}; pub use config_updater::{fetch_proxy_config, me_config_updater};

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
@@ -92,8 +93,16 @@ mod tests {
pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingReport> { pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingReport> {
let mut reports = Vec::new(); let mut reports = Vec::new();
let v4_map = pool.proxy_map_v4.read().await.clone(); let v4_map = if pool.decision.ipv4_me {
let v6_map = pool.proxy_map_v6.read().await.clone(); pool.proxy_map_v4.read().await.clone()
} else {
HashMap::new()
};
let v6_map = if pool.decision.ipv6_me {
pool.proxy_map_v6.read().await.clone()
} else {
HashMap::new()
};
let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new(); let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new();
for (dc, addrs) in v4_map { for (dc, addrs) in v4_map {

View File

@@ -12,6 +12,8 @@ use std::time::Duration;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::probe::NetworkDecision;
use crate::network::IpFamily;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use super::ConnRegistry; use super::ConnRegistry;
@@ -36,6 +38,8 @@ pub struct MePool {
pub(super) registry: Arc<ConnRegistry>, pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>, pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64, pub(super) rr: AtomicU64,
pub(super) decision: NetworkDecision,
pub(super) rng: Arc<SecureRandom>,
pub(super) proxy_tag: Option<Vec<u8>>, pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>, pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
pub(super) nat_ip_cfg: Option<IpAddr>, pub(super) nat_ip_cfg: Option<IpAddr>,
@@ -48,10 +52,16 @@ pub struct MePool {
pub(super) next_writer_id: AtomicU64, pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>, pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<Option<(std::time::Instant, std::net::SocketAddr)>>>, pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pool_size: usize, pool_size: usize,
} }
#[derive(Debug, Default)]
pub struct NatReflectionCache {
pub v4: Option<(std::time::Instant, std::net::SocketAddr)>,
pub v6: Option<(std::time::Instant, std::net::SocketAddr)>,
}
impl MePool { impl MePool {
pub fn new( pub fn new(
proxy_tag: Option<Vec<u8>>, proxy_tag: Option<Vec<u8>>,
@@ -62,11 +72,15 @@ impl MePool {
proxy_map_v4: HashMap<i32, Vec<(IpAddr, u16)>>, proxy_map_v4: HashMap<i32, Vec<(IpAddr, u16)>>,
proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>, proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>,
default_dc: Option<i32>, default_dc: Option<i32>,
decision: NetworkDecision,
rng: Arc<SecureRandom>,
) -> Arc<Self> { ) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
registry: Arc::new(ConnRegistry::new()), registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())), writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0), rr: AtomicU64::new(0),
decision,
rng,
proxy_tag, proxy_tag,
proxy_secret: Arc::new(RwLock::new(proxy_secret)), proxy_secret: Arc::new(RwLock::new(proxy_secret)),
nat_ip_cfg: nat_ip, nat_ip_cfg: nat_ip,
@@ -80,7 +94,7 @@ impl MePool {
next_writer_id: AtomicU64::new(1), next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())), ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(None)), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
}) })
} }
@@ -103,29 +117,30 @@ impl MePool {
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) { pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet; use std::collections::HashSet;
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let writers = self.writers.read().await; let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect(); let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers); drop(writers);
for (_dc, addrs) in map.iter() { for family in self.family_order() {
let dc_addrs: Vec<SocketAddr> = addrs let map = self.proxy_map_for_family(family).await;
.iter() for (_dc, addrs) in map.iter() {
.map(|(ip, port)| SocketAddr::new(*ip, *port)) let dc_addrs: Vec<SocketAddr> = addrs
.collect(); .iter()
if !dc_addrs.iter().any(|a| current.contains(a)) { .map(|(ip, port)| SocketAddr::new(*ip, *port))
let mut shuffled = dc_addrs.clone(); .collect();
shuffled.shuffle(&mut rand::rng()); if !dc_addrs.iter().any(|a| current.contains(a)) {
for addr in shuffled { let mut shuffled = dc_addrs.clone();
if self.connect_one(addr, rng).await.is_ok() { shuffled.shuffle(&mut rand::rng());
break; for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
}
} }
} }
} }
if !self.decision.effective_multipath && !current.is_empty() {
break;
}
} }
} }
@@ -181,47 +196,82 @@ impl MePool {
} }
} }
pub(super) fn family_order(&self) -> Vec<IpFamily> {
let mut order = Vec::new();
if self.decision.prefer_ipv6() {
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
} else {
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
}
order
}
async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap<i32, Vec<(IpAddr, u16)>> {
match family {
IpFamily::V4 => self.proxy_map_v4.read().await.clone(),
IpFamily::V6 => self.proxy_map_v6.read().await.clone(),
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> { pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let map = self.proxy_map_v4.read().await.clone(); let family_order = self.family_order();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let ks = self.key_selector().await; let ks = self.key_selector().await;
info!( info!(
me_servers = map.len(), me_servers = self.proxy_map_v4.read().await.len(),
pool_size, pool_size,
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.len(), secret_len = self.proxy_secret.read().await.len(),
"Initializing ME pool" "Initializing ME pool"
); );
// Ensure at least one connection per DC; run DCs in parallel. for family in family_order {
let mut join = tokio::task::JoinSet::new(); let map = self.proxy_map_for_family(family).await;
for (dc, addrs) in dc_addrs.iter().cloned() { let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
if addrs.is_empty() { .iter()
continue; .map(|(dc, addrs)| (*dc, addrs.clone()))
} .collect();
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs) // Ensure at least one connection per DC; run DCs in parallel.
for (dc, addrs) in dc_addrs.iter() { let mut join = tokio::task::JoinSet::new();
for (ip, port) in addrs { for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size { if self.connection_count() >= pool_size {
break; break;
} }
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
} }
if self.connection_count() >= pool_size {
if !self.decision.effective_multipath && self.connection_count() > 0 {
break; break;
} }
} }
@@ -309,14 +359,15 @@ impl MePool {
} }
_ = tokio::time::sleep(Duration::from_secs(wait)) => {} _ = tokio::time::sleep(Duration::from_secs(wait)) => {}
} }
let sent_id = ping_id;
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1);
{ {
let mut tracker = ping_tracker_ping.lock().await; let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(ping_id, (std::time::Instant::now(), writer_id)); tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
} }
ping_id = ping_id.wrapping_add(1);
if let Err(e) = rpc_w_ping.lock().await.send(&p).await { if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer"); debug!(error = %e, "Active ME ping failed, removing dead writer");
cancel_ping.cancel(); cancel_ping.cancel();

View File

@@ -4,19 +4,14 @@ use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon;
use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult};
use super::MePool; use super::MePool;
use std::time::Instant; use std::time::Instant;
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
#[derive(Debug, Clone, Copy)]
pub struct StunProbeResult {
pub local_addr: std::net::SocketAddr,
pub reflected_addr: std::net::SocketAddr,
}
pub async fn stun_probe(stun_addr: Option<String>) -> Result<Option<StunProbeResult>> {
let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
fetch_stun_binding(&stun_addr).await stun_probe_dual(&stun_addr).await
} }
pub async fn detect_public_ip() -> Option<IpAddr> { pub async fn detect_public_ip() -> Option<IpAddr> {
@@ -35,7 +30,7 @@ impl MePool {
match (ip, nat_ip) { match (ip, nat_ip) {
(IpAddr::V4(src), IpAddr::V4(dst)) (IpAddr::V4(src), IpAddr::V4(dst))
if is_privateish(IpAddr::V4(src)) if is_bogon(IpAddr::V4(src))
|| src.is_loopback() || src.is_loopback()
|| src.is_unspecified() => || src.is_unspecified() =>
{ {
@@ -55,7 +50,7 @@ impl MePool {
) -> std::net::SocketAddr { ) -> std::net::SocketAddr {
let ip = if let Some(r) = reflected { let ip = if let Some(r) = reflected {
// Use reflected IP (not port) only when local address is non-public. // Use reflected IP (not port) only when local address is non-public.
if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() { if is_bogon(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
r.ip() r.ip()
} else { } else {
self.translate_ip_for_nat(addr.ip()) self.translate_ip_for_nat(addr.ip())
@@ -73,7 +68,7 @@ impl MePool {
return self.nat_ip_cfg; return self.nat_ip_cfg;
} }
if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None; return None;
} }
@@ -98,12 +93,19 @@ impl MePool {
} }
} }
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> { pub(super) async fn maybe_reflect_public_addr(
&self,
family: IpFamily,
) -> Option<std::net::SocketAddr> {
const STUN_CACHE_TTL: Duration = Duration::from_secs(600); const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
if let Some((ts, addr)) = *cache { let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
if let Some((ts, addr)) = slot {
if ts.elapsed() < STUN_CACHE_TTL { if ts.elapsed() < STUN_CACHE_TTL {
return Some(addr); return Some(*addr);
} }
} }
} }
@@ -112,12 +114,20 @@ impl MePool {
.nat_stun .nat_stun
.clone() .clone()
.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); .unwrap_or_else(|| "stun.l.google.com:19302".to_string());
match fetch_stun_binding(&stun_addr).await { match stun_probe_dual(&stun_addr).await {
Ok(sa) => { Ok(res) => {
if let Some(result) = sa { let picked: Option<StunProbeResult> = match family {
info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address"); IpFamily::V4 => res.v4,
IpFamily::V6 => res.v6,
};
if let Some(result) = picked {
info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address");
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
*cache = Some((Instant::now(), result.reflected_addr)); let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
*slot = Some((Instant::now(), result.reflected_addr));
} }
Some(result.reflected_addr) Some(result.reflected_addr)
} else { } else {
@@ -158,98 +168,3 @@ async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
let ip = text.trim().parse().ok(); let ip = text.trim().parse().ok();
Ok(ip) Ok(ip)
} }
async fn fetch_stun_binding(stun_addr: &str) -> Result<Option<StunProbeResult>> {
use rand::RngCore;
use tokio::net::UdpSocket;
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
socket
.connect(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
// Build minimal Binding Request.
let mut req = vec![0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::rng().fill_bytes(&mut req[8..20]);
socket
.send(&req)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
let mut buf = [0u8; 128];
let n = socket
.recv(&mut buf)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?;
if n < 20 {
return Ok(None);
}
// Parse attributes.
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 {
break;
}
let family = buf[idx + 1];
if family != 0x01 {
// only IPv4 supported here
break;
}
let port_bytes = [buf[idx + 2], buf[idx + 3]];
let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]];
let (port, ip) = if atype == 0x0020 {
let magic = 0x2112A442u32.to_be_bytes();
let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16);
let ip = [
ip_bytes[0] ^ magic[0],
ip_bytes[1] ^ magic[1],
ip_bytes[2] ^ magic[2],
ip_bytes[3] ^ magic[3],
];
(port, ip)
} else {
(u16::from_be_bytes(port_bytes), ip_bytes)
};
let reflected = std::net::SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])),
port,
);
let local_addr = socket.local_addr().map_err(|e| {
ProxyError::Proxy(format!("STUN local_addr failed: {e}"))
})?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr: reflected,
}));
}
_ => {}
}
idx += (alen + 3) & !3; // 4-byte alignment
}
Ok(None)
}
fn is_privateish(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
IpAddr::V6(v6) => v6.is_unique_local(),
}
}

View File

@@ -152,6 +152,9 @@ pub(crate) async fn reader_loop(
entry.1 = entry.1 * 0.8 + rtt * 0.2; entry.1 = entry.1 * 0.8 + rtt * 0.2;
if rtt < entry.0 { if rtt < entry.0 {
entry.0 = rtt; entry.0 = rtt;
} else {
// allow slow baseline drift upward to avoid stale minimum
entry.0 = entry.0 * 0.99 + rtt * 0.01;
} }
let degraded_now = entry.1 > entry.0 * 2.0; let degraded_now = entry.1 > entry.0 * 2.0;
degraded.store(degraded_now, Ordering::Relaxed); degraded.store(degraded_now, Ordering::Relaxed);

View File

@@ -1,4 +1,5 @@
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration; use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
@@ -15,7 +16,12 @@ pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interva
let candidate = { let candidate = {
let ws = pool.writers.read().await; let ws = pool.writers.read().await;
ws.get(0).cloned() if ws.is_empty() {
None
} else {
let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len();
ws.get(idx).cloned()
}
}; };
let Some(w) = candidate else { let Some(w) = candidate else {
@@ -34,4 +40,3 @@ pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interva
} }
} }
} }

View File

@@ -3,15 +3,14 @@ use std::sync::Arc;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration; use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::RPC_CLOSE_EXT_U32; use crate::protocol::constants::RPC_CLOSE_EXT_U32;
use super::MePool; use super::MePool;
use super::wire::build_proxy_req_payload; use super::wire::build_proxy_req_payload;
use crate::crypto::SecureRandom;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use super::registry::ConnMeta; use super::registry::ConnMeta;
@@ -84,7 +83,7 @@ impl MePool {
drop(map); drop(map);
for (ip, port) in shuffled { for (ip, port) in shuffled {
let addr = SocketAddr::new(ip, port); let addr = SocketAddr::new(ip, port);
if self.connect_one(addr, &SecureRandom::new()).await.is_ok() { if self.connect_one(addr, self.rng.as_ref()).await.is_ok() {
break; break;
} }
} }
@@ -173,32 +172,44 @@ impl MePool {
writers: &[super::pool::MeWriter], writers: &[super::pool::MeWriter],
target_dc: i16, target_dc: i16,
) -> Vec<usize> { ) -> Vec<usize> {
let mut preferred = Vec::<SocketAddr>::new();
let key = target_dc as i32; let key = target_dc as i32;
let map = self.proxy_map_v4.read().await; let mut preferred = Vec::<SocketAddr>::new();
if let Some(v) = map.get(&key) { for family in self.family_order() {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); let map_guard = match family {
} IpFamily::V4 => self.proxy_map_v4.read().await,
if preferred.is_empty() { IpFamily::V6 => self.proxy_map_v6.read().await,
let abs = key.abs(); };
if let Some(v) = map.get(&abs) {
if let Some(v) = map_guard.get(&key) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
} }
} if preferred.is_empty() {
if preferred.is_empty() { let abs = key.abs();
let abs = key.abs(); if let Some(v) = map_guard.get(&abs) {
if let Some(v) = map.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let def = self.default_dc.load(Ordering::Relaxed);
if def != 0 {
if let Some(v) = map.get(&def) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
} }
} }
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map_guard.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let def = self.default_dc.load(Ordering::Relaxed);
if def != 0 {
if let Some(v) = map_guard.get(&def) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
}
drop(map_guard);
if !preferred.is_empty() && !self.decision.effective_multipath {
break;
}
} }
if preferred.is_empty() { if preferred.is_empty() {