Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-02-18 06:01:52 +03:00
parent 8046381939
commit eb9ac7fae4
9 changed files with 259 additions and 282 deletions

View File

@@ -12,6 +12,8 @@ use std::time::Duration;
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
use crate::network::probe::NetworkDecision;
use crate::network::IpFamily;
use crate::protocol::constants::*;
use super::ConnRegistry;
@@ -36,6 +38,8 @@ pub struct MePool {
pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64,
pub(super) decision: NetworkDecision,
pub(super) rng: Arc<SecureRandom>,
pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
pub(super) nat_ip_cfg: Option<IpAddr>,
@@ -48,10 +52,16 @@ pub struct MePool {
pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<Option<(std::time::Instant, std::net::SocketAddr)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pool_size: usize,
}
#[derive(Debug, Default)]
pub struct NatReflectionCache {
pub v4: Option<(std::time::Instant, std::net::SocketAddr)>,
pub v6: Option<(std::time::Instant, std::net::SocketAddr)>,
}
impl MePool {
pub fn new(
proxy_tag: Option<Vec<u8>>,
@@ -62,11 +72,15 @@ impl MePool {
proxy_map_v4: HashMap<i32, Vec<(IpAddr, u16)>>,
proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>,
default_dc: Option<i32>,
decision: NetworkDecision,
rng: Arc<SecureRandom>,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
decision,
rng,
proxy_tag,
proxy_secret: Arc::new(RwLock::new(proxy_secret)),
nat_ip_cfg: nat_ip,
@@ -80,7 +94,7 @@ impl MePool {
next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(None)),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
})
}
@@ -103,29 +117,30 @@ impl MePool {
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet;
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers);
for (_dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
if !dc_addrs.iter().any(|a| current.contains(a)) {
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
for family in self.family_order() {
let map = self.proxy_map_for_family(family).await;
for (_dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
if !dc_addrs.iter().any(|a| current.contains(a)) {
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
}
}
}
}
if !self.decision.effective_multipath && !current.is_empty() {
break;
}
}
}
@@ -181,47 +196,82 @@ impl MePool {
}
}
pub(super) fn family_order(&self) -> Vec<IpFamily> {
let mut order = Vec::new();
if self.decision.prefer_ipv6() {
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
} else {
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
}
order
}
async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap<i32, Vec<(IpAddr, u16)>> {
match family {
IpFamily::V4 => self.proxy_map_v4.read().await.clone(),
IpFamily::V6 => self.proxy_map_v6.read().await.clone(),
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let family_order = self.family_order();
let ks = self.key_selector().await;
info!(
me_servers = map.len(),
me_servers = self.proxy_map_v4.read().await.len(),
pool_size,
key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.len(),
"Initializing ME pool"
);
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
for family in family_order {
let map = self.proxy_map_for_family(family).await;
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size {
if !self.decision.effective_multipath && self.connection_count() > 0 {
break;
}
}
@@ -309,14 +359,15 @@ impl MePool {
}
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
}
let sent_id = ping_id;
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1);
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(ping_id, (std::time::Instant::now(), writer_id));
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
cancel_ping.cancel();