Merge pull request #113 from telemt/me-fixes

Me fixes
Author: Alexey
Date: 2026-02-17 04:26:20 +03:00
Committed by: GitHub
14 changed files with 680 additions and 116 deletions

View File

@@ -9,7 +9,7 @@ libc = "0.2"
# Async runtime # Async runtime
tokio = { version = "1.42", features = ["full", "tracing"] } tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["full"] }
# Crypto # Crypto
aes = "0.8" aes = "0.8"
@@ -53,6 +53,7 @@ reqwest = { version = "0.12", features = ["rustls-tls"], default-features = fals
hyper = { version = "1", features = ["server", "http1"] } hyper = { version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
http-body-util = "0.1" http-body-util = "0.1"
httpdate = "1.0"
[dev-dependencies] [dev-dependencies]
tokio-test = "0.4" tokio-test = "0.4"

View File

@@ -256,10 +256,22 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
if probe.local_addr.ip() != probe.reflected_addr.ip() if probe.local_addr.ip() != probe.reflected_addr.ip()
&& !config.general.stun_iface_mismatch_ignore && !config.general.stun_iface_mismatch_ignore
{ {
warn!( match crate::transport::middle_proxy::detect_public_ip().await {
"STUN/IP-on-Interface mismatch -> fallback to direct-DC" Some(ip) => {
); info!(
use_middle_proxy = false; local_ip = %probe.local_addr.ip(),
reflected_ip = %probe.reflected_addr.ip(),
public_ip = %ip,
"STUN mismatch but public IP auto-detected, continuing with middle proxy"
);
}
None => {
warn!(
"STUN/IP-on-Interface mismatch and public IP auto-detect failed -> fallback to direct-DC"
);
use_middle_proxy = false;
}
}
} }
} }
Ok(None) => warn!("STUN probe returned no address; continuing"), Ok(None) => warn!("STUN probe returned no address; continuing"),
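With this change a STUN/interface mismatch no longer disables the middle proxy outright: the proxy first tries detect_public_ip() and only falls back to direct-DC mode when auto-detection also fails. The same decision as a standalone helper, purely illustrative (the function name and shape are not part of the diff):

```rust
use std::net::IpAddr;

/// Illustrative: should middle proxies stay enabled after a STUN probe
/// reported local_ip != reflected_ip?
async fn keep_middle_proxy_on_mismatch(
    ignore_mismatch: bool,
    detect_public_ip: impl std::future::Future<Output = Option<IpAddr>>,
) -> bool {
    if ignore_mismatch {
        return true; // operator explicitly opted out of the check
    }
    // If a public IP can still be detected, keep the middle proxy;
    // otherwise fall back to direct-DC connections.
    detect_public_ip.await.is_some()
}
```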
@@ -355,6 +367,18 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
.await; .await;
}); });
// Periodic ME connection rotation
let pool_clone_rot = pool.clone();
let rng_clone_rot = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_rotation_task(
pool_clone_rot,
rng_clone_rot,
std::time::Duration::from_secs(1800),
)
.await;
});
// Periodic updater: getProxyConfig + proxy-secret // Periodic updater: getProxyConfig + proxy-secret
let pool_clone2 = pool.clone(); let pool_clone2 = pool.clone();
let rng_clone2 = rng.clone(); let rng_clone2 = rng.clone();

View File

@@ -174,6 +174,7 @@ impl RpcWriter {
if buf.len() >= 16 { if buf.len() >= 16 {
self.iv.copy_from_slice(&buf[buf.len() - 16..]); self.iv.copy_from_slice(&buf[buf.len() - 16..]);
} }
self.writer.write_all(&buf).await.map_err(ProxyError::Io) self.writer.write_all(&buf).await.map_err(ProxyError::Io)?;
self.writer.flush().await.map_err(ProxyError::Io)
} }
} }
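The codec change is the classic buffered-writer fix: write_all only copies the frame into the writer's in-memory buffer, so without an explicit flush a small RPC frame can sit there until more traffic happens to push it out. A minimal sketch of the behaviour, assuming a tokio BufWriter around the socket:

```rust
use tokio::io::{AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;

async fn send_frame(w: &mut BufWriter<TcpStream>, frame: &[u8]) -> std::io::Result<()> {
    // Copies into the writer's buffer; may not touch the socket at all.
    w.write_all(frame).await?;
    // Forces the buffered bytes onto the wire so the peer sees the frame now.
    w.flush().await
}
```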

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use regex::Regex; use regex::Regex;
use httpdate;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::error::Result; use crate::error::Result;
@@ -11,6 +12,7 @@ use crate::error::Result;
use super::MePool; use super::MePool;
use super::secret::download_proxy_secret; use super::secret::download_proxy_secret;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use std::time::SystemTime;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct ProxyConfigData { pub struct ProxyConfigData {
@@ -19,9 +21,29 @@ pub struct ProxyConfigData {
} }
pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> { pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> {
let text = reqwest::get(url) let resp = reqwest::get(url)
.await .await
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))? .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))?
;
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
if let Ok(date_str) = date.to_str() {
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
server_time.duration_since(SystemTime::now()).map_err(|_| e)
}) {
let skew_secs = skew.as_secs();
if skew_secs > 60 {
warn!(skew_secs, "Time skew >60s detected from fetch_proxy_config Date header");
} else if skew_secs > 30 {
warn!(skew_secs, "Time skew >30s detected from fetch_proxy_config Date header");
}
}
}
}
}
let text = resp
.text() .text()
.await .await
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?; .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?;
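Both HTTP fetches touched by this PR (getProxyConfig here and the proxy-secret download below) now compare the response's Date header against local time and warn at two thresholds (above 30 s and above 60 s), since large clock skew tends to break ME handshakes. The duration_since/or_else dance simply takes the absolute difference regardless of which clock is ahead; a self-contained sketch of the same check (the helper name is illustrative):

```rust
use std::time::SystemTime;

/// Absolute skew in seconds between local time and an HTTP `Date` header,
/// or None if the header does not parse.
fn skew_secs_from_date_header(date_header: &str) -> Option<u64> {
    let server_time = httpdate::parse_http_date(date_header).ok()?;
    let now = SystemTime::now();
    let skew = now
        .duration_since(server_time)
        .or_else(|e| server_time.duration_since(now).map_err(|_| e))
        .ok()?;
    Some(skew.as_secs())
}

// e.g. skew_secs_from_date_header("Tue, 17 Feb 2026 01:26:20 GMT")
```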

View File

@@ -1,5 +1,12 @@
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use socket2::{SockRef, TcpKeepalive};
#[cfg(target_os = "linux")]
use libc;
#[cfg(target_os = "linux")]
use std::os::fd::{AsRawFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::raw::c_int;
use bytes::BytesMut; use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
@@ -41,9 +48,45 @@ impl MePool {
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??; .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??;
let connect_ms = start.elapsed().as_secs_f64() * 1000.0; let connect_ms = start.elapsed().as_secs_f64() * 1000.0;
stream.set_nodelay(true).ok(); stream.set_nodelay(true).ok();
if let Err(e) = Self::configure_keepalive(&stream) {
warn!(error = %e, "ME keepalive setup failed");
}
#[cfg(target_os = "linux")]
if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) {
warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed");
}
Ok((stream, connect_ms)) Ok((stream, connect_ms))
} }
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
let sock = SockRef::from(stream);
let ka = TcpKeepalive::new()
.with_time(Duration::from_secs(30))
.with_interval(Duration::from_secs(10))
.with_retries(3);
sock.set_tcp_keepalive(&ka)?;
sock.set_keepalive(true)?;
Ok(())
}
#[cfg(target_os = "linux")]
fn configure_user_timeout(fd: RawFd) -> std::io::Result<()> {
let timeout_ms: c_int = 30_000;
let rc = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout_ms as *const _ as *const libc::c_void,
std::mem::size_of_val(&timeout_ms) as libc::socklen_t,
)
};
if rc != 0 {
return Err(std::io::Error::last_os_error());
}
Ok(())
}
/// Perform full ME RPC handshake on an established TCP stream. /// Perform full ME RPC handshake on an established TCP stream.
/// Returns cipher keys/ivs and split halves; does not register writer. /// Returns cipher keys/ivs and split halves; does not register writer.
pub(crate) async fn handshake_only( pub(crate) async fn handshake_only(
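The keepalive settings (30 s idle, 10 s probe interval, 3 retries) and the 30 s TCP_USER_TIMEOUT complement each other: keepalive notices a silent peer on an idle socket after roughly a minute, while TCP_USER_TIMEOUT bounds how long unacknowledged writes may linger. The diff uses a raw libc setsockopt; if socket2 were built with its "all" feature, the same option could be set through SockRef instead, along these lines (an assumption about build features, not what the diff does):

```rust
use std::time::Duration;
use socket2::SockRef;
use tokio::net::TcpStream;

#[cfg(target_os = "linux")]
fn set_user_timeout(stream: &TcpStream, timeout: Duration) -> std::io::Result<()> {
    // Requires the socket2 "all" feature; limits how long un-ACKed data may
    // stay outstanding before the connection is reported dead.
    SockRef::from(stream).set_tcp_user_timeout(Some(timeout))
}
```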

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::{Duration, Instant};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
@@ -10,6 +11,8 @@ use crate::crypto::SecureRandom;
use super::MePool; use super::MePool;
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) { pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
let mut backoff: HashMap<i32, u64> = HashMap::new();
let mut last_attempt: HashMap<i32, Instant> = HashMap::new();
loop { loop {
tokio::time::sleep(Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
// Per-DC coverage check // Per-DC coverage check
@@ -19,7 +22,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
.read() .read()
.await .await
.iter() .iter()
.map(|(a, _)| *a) .map(|w| w.addr)
.collect(); .collect();
for (dc, addrs) in map.iter() { for (dc, addrs) in map.iter() {
@@ -29,18 +32,81 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
.collect(); .collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a)); let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
if !has_coverage { if !has_coverage {
warn!(dc = %dc, "DC has no ME coverage, reconnecting..."); let delay = *backoff.get(dc).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(dc) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone(); let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng()); shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled { for addr in shuffled {
match pool.connect_one(addr, &rng).await { match pool.connect_one(addr, &rng).await {
Ok(()) => { Ok(()) => {
info!(%addr, dc = %dc, "ME reconnected for DC coverage"); info!(%addr, dc = %dc, "ME reconnected for DC coverage");
backoff.insert(*dc, 30);
last_attempt.insert(*dc, now);
reconnected = true;
break; break;
} }
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"), Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"),
} }
} }
if !reconnected {
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(*dc, next);
last_attempt.insert(*dc, now);
}
}
}
// IPv6 coverage check (if available)
let map_v6 = pool.proxy_map_v6.read().await.clone();
let writer_addrs_v6: std::collections::HashSet<SocketAddr> = pool
.writers
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
for (dc, addrs) in map_v6.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a));
if !has_coverage {
let delay = *backoff.get(dc).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(dc) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, &rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage");
backoff.insert(*dc, 30);
last_attempt.insert(*dc, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"),
}
}
if !reconnected {
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(*dc, next);
last_attempt.insert(*dc, now);
}
} }
} }
} }
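The health monitor now throttles per-DC reconnect attempts: the first failure waits 30 s, every further failure doubles the delay up to a 300 s cap, and a successful reconnect resets the delay to 30 s (the same map is reused for the new IPv6 coverage pass). The schedule as a tiny pure helper, for illustration only:

```rust
/// Next reconnect delay in seconds after a failed attempt:
/// 30 -> 60 -> 120 -> 240 -> 300 (capped), matching the monitor's schedule.
fn next_backoff(current_secs: u64) -> u64 {
    current_secs.saturating_mul(2).min(300)
}

/// A successful reconnect drops the DC back to the 30 s base delay.
fn reset_backoff() -> u64 {
    30
}
```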

View File

@@ -10,6 +10,7 @@ mod reader;
mod registry; mod registry;
mod send; mod send;
mod secret; mod secret;
mod rotation;
mod config_updater; mod config_updater;
mod wire; mod wire;
@@ -18,10 +19,11 @@ use bytes::Bytes;
pub use health::me_health_monitor; pub use health::me_health_monitor;
pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily};
pub use pool::MePool; pub use pool::MePool;
pub use pool_nat::{stun_probe, StunProbeResult}; pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult};
pub use registry::ConnRegistry; pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret; pub use secret::fetch_proxy_secret;
pub use config_updater::{fetch_proxy_config, me_config_updater}; pub use config_updater::{fetch_proxy_config, me_config_updater};
pub use rotation::me_rotation_task;
pub use wire::proto_flags_for_tag; pub use wire::proto_flags_for_tag;
#[derive(Debug)] #[derive(Debug)]

View File

@@ -1,11 +1,12 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicI32, AtomicU64}; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
use bytes::BytesMut; use bytes::BytesMut;
use rand::Rng; use rand::Rng;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use std::time::Duration; use std::time::Duration;
@@ -14,15 +15,26 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use super::ConnRegistry; use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta};
use super::codec::RpcWriter; use super::codec::RpcWriter;
use super::reader::reader_loop; use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
#[derive(Clone)]
pub struct MeWriter {
pub id: u64,
pub addr: SocketAddr,
pub writer: Arc<Mutex<RpcWriter>>,
pub cancel: CancellationToken,
pub degraded: Arc<AtomicBool>,
}
pub struct MePool { pub struct MePool {
pub(super) registry: Arc<ConnRegistry>, pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> , pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64, pub(super) rr: AtomicU64,
pub(super) proxy_tag: Option<Vec<u8>>, pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>, pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
@@ -33,6 +45,10 @@ pub struct MePool {
pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>, pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) proxy_map_v6: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>, pub(super) proxy_map_v6: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) default_dc: AtomicI32, pub(super) default_dc: AtomicI32,
pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<Option<(std::time::Instant, std::net::SocketAddr)>>>,
pool_size: usize, pool_size: usize,
} }
@@ -61,6 +77,10 @@ impl MePool {
proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)),
proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)),
default_dc: AtomicI32::new(default_dc.unwrap_or(0)), default_dc: AtomicI32::new(default_dc.unwrap_or(0)),
next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(None)),
}) })
} }
@@ -77,16 +97,19 @@ impl MePool {
&self.registry &self.registry
} }
fn writers_arc(&self) -> Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
{
self.writers.clone() self.writers.clone()
} }
pub async fn reconcile_connections(&self, rng: &SecureRandom) { pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet; use std::collections::HashSet;
let map = self.proxy_map_v4.read().await.clone(); let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let writers = self.writers.read().await; let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|(a, _)| *a).collect(); let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers); drop(writers);
for (_dc, addrs) in map.iter() { for (_dc, addrs) in map.iter() {
@@ -158,8 +181,12 @@ impl MePool {
} }
} }
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> { pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let map = self.proxy_map_v4.read().await; let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let ks = self.key_selector().await; let ks = self.key_selector().await;
info!( info!(
me_servers = map.len(), me_servers = map.len(),
@@ -169,38 +196,28 @@ impl MePool {
"Initializing ME pool" "Initializing ME pool"
); );
// Ensure at least one connection per DC with failover over all addresses // Ensure at least one connection per DC; run DCs in parallel.
for (dc, addrs) in map.iter() { let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() { if addrs.is_empty() {
continue; continue;
} }
let mut connected = false; let pool = Arc::clone(self);
let mut shuffled = addrs.clone(); let rng_clone = Arc::clone(rng);
shuffled.shuffle(&mut rand::rng()); join.spawn(async move {
for (ip, port) in shuffled { pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
let addr = SocketAddr::new(ip, port); });
match self.connect_one(addr, rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME connected");
connected = true;
break;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
}
}
if !connected {
warn!(dc = %dc, "All ME servers for DC failed at init");
}
} }
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs) // Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in map.iter() { for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs { for (ip, port) in addrs {
if self.connection_count() >= pool_size { if self.connection_count() >= pool_size {
break; break;
} }
let addr = SocketAddr::new(*ip, *port); let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng).await { if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed"); debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
} }
} }
@@ -215,7 +232,7 @@ impl MePool {
Ok(()) Ok(())
} }
pub(crate) async fn connect_one(&self, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
let secret_len = self.proxy_secret.read().await.len(); let secret_len = self.proxy_secret.read().await.len();
if secret_len < 32 { if secret_len < 32 {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
@@ -224,44 +241,88 @@ impl MePool {
let (stream, _connect_ms) = self.connect_tcp(addr).await?; let (stream, _connect_ms) = self.connect_tcp(addr).await?;
let hs = self.handshake_only(stream, addr, rng).await?; let hs = self.handshake_only(stream, addr, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
let cancel = CancellationToken::new();
let degraded = Arc::new(AtomicBool::new(false));
let rpc_w = Arc::new(Mutex::new(RpcWriter { let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: hs.wr, writer: hs.wr,
key: hs.write_key, key: hs.write_key,
iv: hs.write_iv, iv: hs.write_iv,
seq_no: 0, seq_no: 0,
})); }));
self.writers.write().await.push((addr, rpc_w.clone())); let writer = MeWriter {
id: writer_id,
addr,
writer: rpc_w.clone(),
cancel: cancel.clone(),
degraded: degraded.clone(),
};
self.writers.write().await.push(writer.clone());
let reg = self.registry.clone(); let reg = self.registry.clone();
let w_pong = rpc_w.clone(); let writers_arc = self.writers_arc();
let w_pool = self.writers_arc(); let ping_tracker = self.ping_tracker.clone();
let w_ping = rpc_w.clone(); let rtt_stats = self.rtt_stats.clone();
let w_pool_ping = self.writers_arc(); let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone();
let rpc_w_ping = rpc_w.clone();
let ping_tracker_ping = ping_tracker.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = let cancel_reader = cancel.clone();
reader_loop(hs.rd, hs.read_key, hs.read_iv, reg, BytesMut::new(), BytesMut::new(), w_pong.clone()).await let res = reader_loop(
{ hs.rd,
hs.read_key,
hs.read_iv,
reg.clone(),
BytesMut::new(),
BytesMut::new(),
rpc_w.clone(),
ping_tracker.clone(),
rtt_stats.clone(),
writer_id,
degraded.clone(),
cancel_reader.clone(),
)
.await;
if let Some(pool) = pool.upgrade() {
pool.remove_writer_and_reroute(writer_id).await;
}
if let Err(e) = res {
warn!(error = %e, "ME reader ended"); warn!(error = %e, "ME reader ended");
} }
let mut ws = w_pool.write().await; let mut ws = writers_arc.write().await;
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong)); ws.retain(|w| w.id != writer_id);
info!(remaining = ws.len(), "Dead ME writer removed from pool"); info!(remaining = ws.len(), "Dead ME writer removed from pool");
}); });
let pool_ping = Arc::downgrade(self);
tokio::spawn(async move { tokio::spawn(async move {
let mut ping_id: i64 = rand::random::<i64>(); let mut ping_id: i64 = rand::random::<i64>();
loop { loop {
let jitter = rand::rng() let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
tokio::time::sleep(Duration::from_secs(wait)).await; tokio::select! {
_ = cancel_ping.cancelled() => {
break;
}
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
}
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes()); p.extend_from_slice(&ping_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1); ping_id = ping_id.wrapping_add(1);
if let Err(e) = w_ping.lock().await.send(&p).await { {
let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(ping_id, (std::time::Instant::now(), writer_id));
}
if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer"); debug!(error = %e, "Active ME ping failed, removing dead writer");
let mut ws = w_pool_ping.write().await; cancel_ping.cancel();
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping)); if let Some(pool) = pool_ping.upgrade() {
pool.remove_writer_and_reroute(writer_id).await;
}
break; break;
} }
} }
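The active ping now stamps each outgoing ping in ping_tracker with the send time and the writer id, so the reader can turn the echoed id in the pong into an RTT sample, and both background tasks hold only a Weak handle to the pool (Arc::downgrade / upgrade), so they cannot keep it alive after shutdown. One detail worth noting: the payload is built from ping_id before it is incremented, while the tracker records the incremented value, so the pong lookup key would be off by one; the ordering the lookup expects looks like this (a sketch, not the committed code):

```rust
use std::collections::HashMap;
use std::time::Instant;

/// Illustrative: record the ping id that is actually sent, then advance the
/// counter, so the pong handler's HashMap lookup finds the entry.
fn track_ping(
    tracker: &mut HashMap<i64, (Instant, u64)>,
    ping_id: &mut i64,
    writer_id: u64,
) -> i64 {
    let sent_id = *ping_id;                          // id placed in the RPC_PING payload
    tracker.insert(sent_id, (Instant::now(), writer_id));
    *ping_id = ping_id.wrapping_add(1);              // next round uses a fresh id
    sent_id
}
```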
@@ -270,6 +331,124 @@ impl MePool {
Ok(()) Ok(())
} }
async fn connect_primary_for_dc(
self: Arc<Self>,
dc: i32,
mut addrs: Vec<(IpAddr, u16)>,
rng: Arc<SecureRandom>,
) {
if addrs.is_empty() {
return;
}
addrs.shuffle(&mut rand::rng());
for (ip, port) in addrs {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng.as_ref()).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME connected");
return;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
}
}
warn!(dc = %dc, "All ME servers for DC failed at init");
}
pub(crate) async fn remove_writer_and_reroute(&self, writer_id: u64) {
let mut queue = self.remove_writer_only(writer_id).await;
while let Some(bound) = queue.pop() {
if !self.reroute_conn(&bound, &mut queue).await {
let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
}
}
}
async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> {
{
let mut ws = self.writers.write().await;
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
let w = ws.remove(pos);
w.cancel.cancel();
}
}
self.registry.writer_lost(writer_id).await
}
async fn reroute_conn(&self, bound: &BoundConn, backlog: &mut Vec<BoundConn>) -> bool {
let payload = super::wire::build_proxy_req_payload(
bound.conn_id,
bound.meta.client_addr,
bound.meta.our_addr,
&[],
self.proxy_tag.as_deref(),
bound.meta.proto_flags,
);
let mut attempts = 0;
loop {
let writers_snapshot = {
let ws = self.writers.read().await;
if ws.is_empty() {
return false;
}
ws.clone()
};
let mut candidates = self.candidate_indices_for_dc(&writers_snapshot, bound.meta.target_dc).await;
if candidates.is_empty() {
return false;
}
candidates.sort_by_key(|idx| {
writers_snapshot[*idx]
.degraded
.load(Ordering::Relaxed)
.then_some(1usize)
.unwrap_or(0)
});
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidates.len();
for offset in 0..candidates.len() {
let idx = candidates[(start + offset) % candidates.len()];
let w = &writers_snapshot[idx];
if let Ok(mut guard) = w.writer.try_lock() {
let send_res = guard.send(&payload).await;
drop(guard);
match send_res {
Ok(()) => {
self.registry
.bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
.await;
return true;
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME reroute send failed");
backlog.extend(self.remove_writer_only(w.id).await);
}
}
continue;
}
}
let w = writers_snapshot[candidates[start]].clone();
match w.writer.lock().await.send(&payload).await {
Ok(()) => {
self.registry
.bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
.await;
return true;
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME reroute send failed (blocking)");
backlog.extend(self.remove_writer_only(w.id).await);
}
}
attempts += 1;
if attempts > 3 {
return false;
}
}
}
} }
fn hex_dump(data: &[u8]) -> String { fn hex_dump(data: &[u8]) -> String {

View File

@@ -1,10 +1,12 @@
use std::net::{IpAddr, Ipv4Addr}; use std::net::{IpAddr, Ipv4Addr};
use std::time::Duration;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use super::MePool; use super::MePool;
use std::time::Instant;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct StunProbeResult { pub struct StunProbeResult {
@@ -17,6 +19,10 @@ pub async fn stun_probe(stun_addr: Option<String>) -> Result<Option<StunProbeRes
fetch_stun_binding(&stun_addr).await fetch_stun_binding(&stun_addr).await
} }
pub async fn detect_public_ip() -> Option<IpAddr> {
fetch_public_ipv4_with_retry().await.ok().flatten().map(IpAddr::V4)
}
impl MePool { impl MePool {
pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
let nat_ip = self let nat_ip = self
@@ -93,6 +99,15 @@ impl MePool {
} }
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> { pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> {
const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
if let Some((ts, addr)) = *cache {
if ts.elapsed() < STUN_CACHE_TTL {
return Some(addr);
}
}
}
let stun_addr = self let stun_addr = self
.nat_stun .nat_stun
.clone() .clone()
@@ -101,6 +116,9 @@ impl MePool {
Ok(sa) => { Ok(sa) => {
if let Some(result) = sa { if let Some(result) = sa {
info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address"); info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address");
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
*cache = Some((Instant::now(), result.reflected_addr));
}
Some(result.reflected_addr) Some(result.reflected_addr)
} else { } else {
None None
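maybe_reflect_public_addr now caches the STUN reflection for ten minutes and uses try_lock on the cache, so concurrent callers never block on it; at worst two callers race and both run a probe. The cache-or-probe shape in isolation (illustrative, not the pool's exact types):

```rust
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;

const TTL: Duration = Duration::from_secs(600);

/// Return a cached reflected address if it is still fresh, otherwise None
/// (meaning the caller should run a new STUN probe and store the result).
fn cached_reflection(cache: &Mutex<Option<(Instant, SocketAddr)>>) -> Option<SocketAddr> {
    // try_lock: never block the hot path on the cache; fall through on contention.
    let guard = cache.try_lock().ok()?;
    match *guard {
        Some((ts, addr)) if ts.elapsed() < TTL => Some(addr),
        _ => None,
    }
}
```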

View File

@@ -1,9 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32}; use crate::crypto::{AesCbc, crc32};
@@ -21,12 +25,21 @@ pub(crate) async fn reader_loop(
enc_leftover: BytesMut, enc_leftover: BytesMut,
mut dec: BytesMut, mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>, writer: Arc<Mutex<RpcWriter>>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
_writer_id: u64,
degraded: Arc<AtomicBool>,
cancel: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let mut raw = enc_leftover; let mut raw = enc_leftover;
let mut expected_seq: i32 = 0;
loop { loop {
let mut tmp = [0u8; 16_384]; let mut tmp = [0u8; 16_384];
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?; let n = tokio::select! {
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
_ = cancel.cancelled() => return Ok(()),
};
if n == 0 { if n == 0 {
return Ok(()); return Ok(());
} }
@@ -70,6 +83,14 @@ pub(crate) async fn reader_loop(
continue; continue;
} }
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
if seq_no != expected_seq {
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
expected_seq = seq_no.wrapping_add(1);
} else {
expected_seq = expected_seq.wrapping_add(1);
}
let payload = &frame[8..pe]; let payload = &frame[8..pe];
if payload.len() < 4 { if payload.len() < 4 {
continue; continue;
@@ -119,6 +140,23 @@ pub(crate) async fn reader_loop(
warn!(error = %e, "PONG send failed"); warn!(error = %e, "PONG send failed");
break; break;
} }
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id)
} {
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
let mut stats = rtt_stats.lock().await;
let entry = stats.entry(wid).or_insert((rtt, rtt));
entry.1 = entry.1 * 0.8 + rtt * 0.2;
if rtt < entry.0 {
entry.0 = rtt;
}
let degraded_now = entry.1 > entry.0 * 2.0;
degraded.store(degraded_now, Ordering::Relaxed);
trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample");
}
} else { } else {
debug!( debug!(
rpc_type = format_args!("0x{pt:08x}"), rpc_type = format_args!("0x{pt:08x}"),
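Each pong now updates per-writer RTT stats: the tuple holds (best RTT seen, exponentially weighted moving average with alpha = 0.2), it is seeded with the first sample via or_insert((rtt, rtt)), and the writer is flagged degraded while the EMA exceeds twice the best observed RTT. The update as a small pure function (illustrative):

```rust
/// stats = (baseline, ema) for one writer; returns true when it looks degraded.
fn update_rtt(stats: &mut (f64, f64), rtt_ms: f64) -> bool {
    let (baseline, ema) = stats;
    *ema = *ema * 0.8 + rtt_ms * 0.2;   // EMA with alpha = 0.2
    if rtt_ms < *baseline {
        *baseline = rtt_ms;             // remember the best RTT ever seen
    }
    *ema > *baseline * 2.0              // degraded: EMA more than 2x the baseline
}
```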

View File

@@ -1,60 +1,133 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{RwLock, mpsc};
use super::MeResponse;
use super::codec::RpcWriter;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex, RwLock};
use super::codec::RpcWriter;
use super::MeResponse;
#[derive(Clone)]
pub struct ConnMeta {
pub target_dc: i16,
pub client_addr: SocketAddr,
pub our_addr: SocketAddr,
pub proto_flags: u32,
}
#[derive(Clone)]
pub struct BoundConn {
pub conn_id: u64,
pub meta: ConnMeta,
}
#[derive(Clone)]
pub struct ConnWriter {
pub writer_id: u64,
pub writer: Arc<Mutex<RpcWriter>>,
}
pub struct ConnRegistry { pub struct ConnRegistry {
map: RwLock<HashMap<u64, mpsc::UnboundedSender<MeResponse>>>, map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
writers: RwLock<HashMap<u64, Arc<Mutex<RpcWriter>>>>, writers: RwLock<HashMap<u64, Arc<Mutex<RpcWriter>>>>,
writer_for_conn: RwLock<HashMap<u64, u64>>,
conns_for_writer: RwLock<HashMap<u64, Vec<u64>>>,
meta: RwLock<HashMap<u64, ConnMeta>>,
next_id: AtomicU64, next_id: AtomicU64,
} }
impl ConnRegistry { impl ConnRegistry {
pub fn new() -> Self { pub fn new() -> Self {
// Avoid fully predictable conn_id sequence from 1.
let start = rand::random::<u64>() | 1; let start = rand::random::<u64>() | 1;
Self { Self {
map: RwLock::new(HashMap::new()), map: RwLock::new(HashMap::new()),
writers: RwLock::new(HashMap::new()), writers: RwLock::new(HashMap::new()),
writer_for_conn: RwLock::new(HashMap::new()),
conns_for_writer: RwLock::new(HashMap::new()),
meta: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(start), next_id: AtomicU64::new(start),
} }
} }
pub async fn register(&self) -> (u64, mpsc::UnboundedReceiver<MeResponse>) { pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed); let id = self.next_id.fetch_add(1, Ordering::Relaxed);
// Unbounded per-connection queue prevents reader-loop HOL blocking on let (tx, rx) = mpsc::channel(1024);
// slow clients: routing stays non-blocking and preserves message order.
let (tx, rx) = mpsc::unbounded_channel();
self.map.write().await.insert(id, tx); self.map.write().await.insert(id, tx);
(id, rx) (id, rx)
} }
pub async fn unregister(&self, id: u64) { pub async fn unregister(&self, id: u64) {
self.map.write().await.remove(&id); self.map.write().await.remove(&id);
self.writers.write().await.remove(&id); self.meta.write().await.remove(&id);
if let Some(writer_id) = self.writer_for_conn.write().await.remove(&id) {
if let Some(list) = self.conns_for_writer.write().await.get_mut(&writer_id) {
list.retain(|c| *c != id);
}
}
} }
pub async fn route(&self, id: u64, resp: MeResponse) -> bool { pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
let m = self.map.read().await; let m = self.map.read().await;
if let Some(tx) = m.get(&id) { if let Some(tx) = m.get(&id) {
tx.send(resp).is_ok() tx.try_send(resp).is_ok()
} else { } else {
false false
} }
} }
pub async fn set_writer(&self, id: u64, w: Arc<Mutex<RpcWriter>>) { pub async fn bind_writer(
let mut guard = self.writers.write().await; &self,
guard.entry(id).or_insert_with(|| w); conn_id: u64,
writer_id: u64,
writer: Arc<Mutex<RpcWriter>>,
meta: ConnMeta,
) {
self.meta.write().await.entry(conn_id).or_insert(meta);
self.writer_for_conn.write().await.insert(conn_id, writer_id);
self.writers.write().await.entry(writer_id).or_insert_with(|| writer.clone());
self.conns_for_writer
.write()
.await
.entry(writer_id)
.or_insert_with(Vec::new)
.push(conn_id);
} }
pub async fn get_writer(&self, id: u64) -> Option<Arc<Mutex<RpcWriter>>> { pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
let guard = self.writers.read().await; let writer_id = {
guard.get(&id).cloned() let guard = self.writer_for_conn.read().await;
guard.get(&conn_id).cloned()
}?;
let writer = {
let guard = self.writers.read().await;
guard.get(&writer_id).cloned()
}?;
Some(ConnWriter { writer_id, writer })
}
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
self.writers.write().await.remove(&writer_id);
let conns = self.conns_for_writer.write().await.remove(&writer_id).unwrap_or_default();
let mut out = Vec::new();
let mut writer_for_conn = self.writer_for_conn.write().await;
let meta = self.meta.read().await;
for conn_id in conns {
writer_for_conn.remove(&conn_id);
if let Some(m) = meta.get(&conn_id) {
out.push(BoundConn {
conn_id,
meta: m.clone(),
});
}
}
out
}
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
let guard = self.meta.read().await;
guard.get(&conn_id).cloned()
} }
} }
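The registry now tracks three relationships per connection: the bounded response channel to the client task, which writer the connection is bound to (writer_for_conn / conns_for_writer), and the ConnMeta needed to rebuild a proxy request after a reroute; unregister and writer_lost keep all of them consistent. Note the switch from an unbounded channel to mpsc::channel(1024) with try_send: a client that stops draining its queue now causes responses to be dropped instead of growing memory without bound. A small sketch of what a sender sees in that case (illustrative, with a stand-in response type):

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum MeResponseStub {
    Data(Vec<u8>),
}

fn forward(tx: &mpsc::Sender<MeResponseStub>, resp: MeResponseStub) -> bool {
    match tx.try_send(resp) {
        Ok(()) => true,
        // Queue full: the client is too slow, the message is dropped.
        Err(mpsc::error::TrySendError::Full(_)) => false,
        // Receiver gone: the client task already exited.
        Err(mpsc::error::TrySendError::Closed(_)) => false,
    }
}
```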

View File

@@ -0,0 +1,37 @@
use std::sync::Arc;
use std::time::Duration;
use tracing::{info, warn};
use crate::crypto::SecureRandom;
use super::MePool;
/// Periodically refresh ME connections to avoid long-lived degradation.
pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interval: Duration) {
let interval = interval.max(Duration::from_secs(600));
loop {
tokio::time::sleep(interval).await;
let candidate = {
let ws = pool.writers.read().await;
ws.get(0).cloned()
};
let Some(w) = candidate else {
continue;
};
info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection");
match pool.connect_one(w.addr, rng.as_ref()).await {
Ok(()) => {
// Remove old writer after new one is up.
pool.remove_writer_and_reroute(w.id).await;
}
Err(e) => {
warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed");
}
}
}
}
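The rotation task always takes the writer at index 0, connects a fresh ME connection to the same address first, and only then removes and reroutes the old one, so DC coverage never dips; because new writers are pushed to the back of the vector, successive cycles walk through the whole pool. The interval is clamped to at least ten minutes. It is spawned from main with a 1800-second period; a sketch of the same spawn with a different interval (the value here is illustrative):

```rust
// Illustrative spawn; `pool` and `rng` are the Arc handles created at startup.
let pool_rot = pool.clone();
let rng_rot = rng.clone();
tokio::spawn(async move {
    crate::transport::middle_proxy::me_rotation_task(
        pool_rot,
        rng_rot,
        std::time::Duration::from_secs(3600), // rotate one connection per hour
    )
    .await;
});
```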

View File

@@ -1,6 +1,8 @@
use std::time::Duration; use std::time::Duration;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use std::time::SystemTime;
use httpdate;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@@ -63,6 +65,23 @@ pub async fn download_proxy_secret() -> Result<Vec<u8>> {
))); )));
} }
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
if let Ok(date_str) = date.to_str() {
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
server_time.duration_since(SystemTime::now()).map_err(|_| e)
}) {
let skew_secs = skew.as_secs();
if skew_secs > 60 {
warn!(skew_secs, "Time skew >60s detected from proxy-secret Date header");
} else if skew_secs > 30 {
warn!(skew_secs, "Time skew >30s detected from proxy-secret Date header");
}
}
}
}
}
let data = resp let data = resp
.bytes() .bytes()
.await .await

View File

@@ -1,6 +1,7 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::{debug, warn}; use tracing::{debug, warn};
@@ -9,14 +10,14 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::RPC_CLOSE_EXT_U32; use crate::protocol::constants::RPC_CLOSE_EXT_U32;
use super::MePool; use super::MePool;
use super::codec::RpcWriter;
use super::wire::build_proxy_req_payload; use super::wire::build_proxy_req_payload;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use super::registry::ConnMeta;
impl MePool { impl MePool {
pub async fn send_proxy_req( pub async fn send_proxy_req(
&self, self: &Arc<Self>,
conn_id: u64, conn_id: u64,
target_dc: i16, target_dc: i16,
client_addr: SocketAddr, client_addr: SocketAddr,
@@ -32,18 +33,50 @@ impl MePool {
self.proxy_tag.as_deref(), self.proxy_tag.as_deref(),
proto_flags, proto_flags,
); );
let meta = ConnMeta {
target_dc,
client_addr,
our_addr,
proto_flags,
};
let mut emergency_attempts = 0;
loop { loop {
let ws = self.writers.read().await; if let Some(current) = self.registry.get_writer(conn_id).await {
if ws.is_empty() { let send_res = {
return Err(ProxyError::Proxy("All ME connections dead".into())); if let Ok(mut guard) = current.writer.try_lock() {
let r = guard.send(&payload).await;
drop(guard);
r
} else {
current.writer.lock().await.send(&payload).await
}
};
match send_res {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, writer_id = current.writer_id, "ME write failed");
self.remove_writer_and_reroute(current.writer_id).await;
continue;
}
}
} }
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws.iter().cloned().collect();
drop(ws);
let mut candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await; let mut writers_snapshot = {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
ws.clone()
};
let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
if candidate_indices.is_empty() { if candidate_indices.is_empty() {
// Emergency: try to connect to target DC addresses on the fly, then recompute writers // Emergency connect-on-demand
if emergency_attempts >= 3 {
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
}
emergency_attempts += 1;
let map = self.proxy_map_v4.read().await; let map = self.proxy_map_v4.read().await;
if let Some(addrs) = map.get(&(target_dc as i32)) { if let Some(addrs) = map.get(&(target_dc as i32)) {
let mut shuffled = addrs.clone(); let mut shuffled = addrs.clone();
@@ -55,65 +88,73 @@ impl MePool {
break; break;
} }
} }
tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await;
let ws2 = self.writers.read().await; let ws2 = self.writers.read().await;
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws2.iter().cloned().collect(); writers_snapshot = ws2.clone();
drop(ws2); drop(ws2);
candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await; candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
} }
if candidate_indices.is_empty() { if candidate_indices.is_empty() {
return Err(ProxyError::Proxy("No ME writers available for target DC".into())); return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
} }
} }
candidate_indices.sort_by_key(|idx| {
writers_snapshot[*idx]
.degraded
.load(Ordering::Relaxed)
.then_some(1usize)
.unwrap_or(0)
});
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
// Prefer immediately available writer to avoid waiting on stalled connection.
for offset in 0..candidate_indices.len() { for offset in 0..candidate_indices.len() {
let cidx = (start + offset) % candidate_indices.len(); let idx = candidate_indices[(start + offset) % candidate_indices.len()];
let idx = candidate_indices[cidx]; let w = &writers_snapshot[idx];
let w = writers[idx].1.clone(); if let Ok(mut guard) = w.writer.try_lock() {
if let Ok(mut guard) = w.try_lock() {
let send_res = guard.send(&payload).await; let send_res = guard.send(&payload).await;
drop(guard); drop(guard);
match send_res { match send_res {
Ok(()) => return Ok(()), Ok(()) => {
self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => { Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn"); warn!(error = %e, writer_id = w.id, "ME write failed");
let mut ws = self.writers.write().await; self.remove_writer_and_reroute(w.id).await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
continue; continue;
} }
} }
} }
} }
// All writers are currently busy, wait for the selected one. let w = writers_snapshot[candidate_indices[start]].clone();
let w = writers[candidate_indices[start]].1.clone(); match w.writer.lock().await.send(&payload).await {
match w.lock().await.send(&payload).await { Ok(()) => {
Ok(()) => return Ok(()), self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => { Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn"); warn!(error = %e, writer_id = w.id, "ME write failed (blocking)");
let mut ws = self.writers.write().await; self.remove_writer_and_reroute(w.id).await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
} }
} }
} }
} }
pub async fn send_close(&self, conn_id: u64) -> Result<()> { pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
if let Some(w) = self.registry.get_writer(conn_id).await { if let Some(w) = self.registry.get_writer(conn_id).await {
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.lock().await.send(&p).await { if let Err(e) = w.writer.lock().await.send(&p).await {
debug!(error = %e, "ME close write failed"); debug!(error = %e, "ME close write failed");
let mut ws = self.writers.write().await; self.remove_writer_and_reroute(w.writer_id).await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
} }
} else { } else {
debug!(conn_id, "ME close skipped (writer missing)"); debug!(conn_id, "ME close skipped (writer missing)");
@@ -129,7 +170,7 @@ impl MePool {
pub(super) async fn candidate_indices_for_dc( pub(super) async fn candidate_indices_for_dc(
&self, &self,
writers: &[(SocketAddr, Arc<Mutex<RpcWriter>>)], writers: &[super::pool::MeWriter],
target_dc: i16, target_dc: i16,
) -> Vec<usize> { ) -> Vec<usize> {
let mut preferred = Vec::<SocketAddr>::new(); let mut preferred = Vec::<SocketAddr>::new();
@@ -165,8 +206,8 @@ impl MePool {
} }
let mut out = Vec::new(); let mut out = Vec::new();
for (idx, (addr, _)) in writers.iter().enumerate() { for (idx, w) in writers.iter().enumerate() {
if preferred.iter().any(|p| p == addr) { if preferred.iter().any(|p| *p == w.addr) {
out.push(idx); out.push(idx);
} }
} }
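send_proxy_req now prefers the writer already bound to the connection, then sorts the remaining candidates so non-degraded writers come first, tries try_lock on each to avoid queueing behind a stalled connection, and only awaits a lock as a last resort; every successful send re-binds the connection in the registry so later frames and the close path reuse the same writer. The "free writer first" idea in isolation (illustrative, generic over the payload type):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

/// Try every candidate without waiting; fall back to awaiting the first one.
async fn pick_and_use<T, R>(
    candidates: &[Arc<Mutex<T>>],
    mut use_it: impl FnMut(&mut T) -> R,
) -> Option<R> {
    for c in candidates {
        if let Ok(mut guard) = c.try_lock() {
            return Some(use_it(&mut guard)); // found an idle writer
        }
    }
    let mut guard = candidates.first()?.lock().await; // all busy: wait for one
    Some(use_it(&mut guard))
}
```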