@@ -9,7 +9,7 @@ libc = "0.2"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1.42", features = ["full", "tracing"] }
|
||||
tokio-util = { version = "0.7", features = ["codec"] }
|
||||
tokio-util = { version = "0.7", features = ["full"] }
|
||||
|
||||
# Crypto
|
||||
aes = "0.8"
|
||||
@@ -53,6 +53,7 @@ reqwest = { version = "0.12", features = ["rustls-tls"], default-features = fals
|
||||
hyper = { version = "1", features = ["server", "http1"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio", "server-auto"] }
|
||||
http-body-util = "0.1"
|
||||
httpdate = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
|
||||
26
src/main.rs
26
src/main.rs
@@ -256,12 +256,24 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
if probe.local_addr.ip() != probe.reflected_addr.ip()
|
||||
&& !config.general.stun_iface_mismatch_ignore
|
||||
{
|
||||
match crate::transport::middle_proxy::detect_public_ip().await {
|
||||
Some(ip) => {
|
||||
info!(
|
||||
local_ip = %probe.local_addr.ip(),
|
||||
reflected_ip = %probe.reflected_addr.ip(),
|
||||
public_ip = %ip,
|
||||
"STUN mismatch but public IP auto-detected, continuing with middle proxy"
|
||||
);
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"STUN/IP-on-Interface mismatch -> fallback to direct-DC"
|
||||
"STUN/IP-on-Interface mismatch and public IP auto-detect failed -> fallback to direct-DC"
|
||||
);
|
||||
use_middle_proxy = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None) => warn!("STUN probe returned no address; continuing"),
|
||||
Err(e) => warn!(error = %e, "STUN probe failed; continuing"),
|
||||
}
|
||||
@@ -355,6 +367,18 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
|
||||
.await;
|
||||
});
|
||||
|
||||
// Periodic ME connection rotation
|
||||
let pool_clone_rot = pool.clone();
|
||||
let rng_clone_rot = rng.clone();
|
||||
tokio::spawn(async move {
|
||||
crate::transport::middle_proxy::me_rotation_task(
|
||||
pool_clone_rot,
|
||||
rng_clone_rot,
|
||||
std::time::Duration::from_secs(1800),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
// Periodic updater: getProxyConfig + proxy-secret
|
||||
let pool_clone2 = pool.clone();
|
||||
let rng_clone2 = rng.clone();
|
||||
|
||||
@@ -174,6 +174,7 @@ impl RpcWriter {
|
||||
if buf.len() >= 16 {
|
||||
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
|
||||
}
|
||||
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
|
||||
self.writer.write_all(&buf).await.map_err(ProxyError::Io)?;
|
||||
self.writer.flush().await.map_err(ProxyError::Io)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use regex::Regex;
|
||||
use httpdate;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::Result;
|
||||
@@ -11,6 +12,7 @@ use crate::error::Result;
|
||||
use super::MePool;
|
||||
use super::secret::download_proxy_secret;
|
||||
use crate::crypto::SecureRandom;
|
||||
use std::time::SystemTime;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ProxyConfigData {
|
||||
@@ -19,9 +21,29 @@ pub struct ProxyConfigData {
|
||||
}
|
||||
|
||||
pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> {
|
||||
let text = reqwest::get(url)
|
||||
let resp = reqwest::get(url)
|
||||
.await
|
||||
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")))?
|
||||
;
|
||||
|
||||
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
|
||||
if let Ok(date_str) = date.to_str() {
|
||||
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
|
||||
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
|
||||
server_time.duration_since(SystemTime::now()).map_err(|_| e)
|
||||
}) {
|
||||
let skew_secs = skew.as_secs();
|
||||
if skew_secs > 60 {
|
||||
warn!(skew_secs, "Time skew >60s detected from fetch_proxy_config Date header");
|
||||
} else if skew_secs > 30 {
|
||||
warn!(skew_secs, "Time skew >30s detected from fetch_proxy_config Date header");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let text = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?;
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::time::{Duration, Instant};
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
#[cfg(target_os = "linux")]
|
||||
use libc;
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::os::fd::{AsRawFd, RawFd};
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
@@ -41,9 +48,45 @@ impl MePool {
|
||||
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??;
|
||||
let connect_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
stream.set_nodelay(true).ok();
|
||||
if let Err(e) = Self::configure_keepalive(&stream) {
|
||||
warn!(error = %e, "ME keepalive setup failed");
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) {
|
||||
warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed");
|
||||
}
|
||||
Ok((stream, connect_ms))
|
||||
}
|
||||
|
||||
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
|
||||
let sock = SockRef::from(stream);
|
||||
let ka = TcpKeepalive::new()
|
||||
.with_time(Duration::from_secs(30))
|
||||
.with_interval(Duration::from_secs(10))
|
||||
.with_retries(3);
|
||||
sock.set_tcp_keepalive(&ka)?;
|
||||
sock.set_keepalive(true)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
fn configure_user_timeout(fd: RawFd) -> std::io::Result<()> {
|
||||
let timeout_ms: c_int = 30_000;
|
||||
let rc = unsafe {
|
||||
libc::setsockopt(
|
||||
fd,
|
||||
libc::IPPROTO_TCP,
|
||||
libc::TCP_USER_TIMEOUT,
|
||||
&timeout_ms as *const _ as *const libc::c_void,
|
||||
std::mem::size_of_val(&timeout_ms) as libc::socklen_t,
|
||||
)
|
||||
};
|
||||
if rc != 0 {
|
||||
return Err(std::io::Error::last_os_error());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Perform full ME RPC handshake on an established TCP stream.
|
||||
/// Returns cipher keys/ivs and split halves; does not register writer.
|
||||
pub(crate) async fn handshake_only(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tracing::{debug, info, warn};
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -10,6 +11,8 @@ use crate::crypto::SecureRandom;
|
||||
use super::MePool;
|
||||
|
||||
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
|
||||
let mut backoff: HashMap<i32, u64> = HashMap::new();
|
||||
let mut last_attempt: HashMap<i32, Instant> = HashMap::new();
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
// Per-DC coverage check
|
||||
@@ -19,7 +22,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(a, _)| *a)
|
||||
.map(|w| w.addr)
|
||||
.collect();
|
||||
|
||||
for (dc, addrs) in map.iter() {
|
||||
@@ -29,18 +32,81 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
|
||||
.collect();
|
||||
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
|
||||
if !has_coverage {
|
||||
warn!(dc = %dc, "DC has no ME coverage, reconnecting...");
|
||||
let delay = *backoff.get(dc).unwrap_or(&30);
|
||||
let now = Instant::now();
|
||||
if let Some(last) = last_attempt.get(dc) {
|
||||
if now.duration_since(*last).as_secs() < delay {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting...");
|
||||
let mut shuffled = dc_addrs.clone();
|
||||
shuffled.shuffle(&mut rand::rng());
|
||||
let mut reconnected = false;
|
||||
for addr in shuffled {
|
||||
match pool.connect_one(addr, &rng).await {
|
||||
Ok(()) => {
|
||||
info!(%addr, dc = %dc, "ME reconnected for DC coverage");
|
||||
backoff.insert(*dc, 30);
|
||||
last_attempt.insert(*dc, now);
|
||||
reconnected = true;
|
||||
break;
|
||||
}
|
||||
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"),
|
||||
}
|
||||
}
|
||||
if !reconnected {
|
||||
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
|
||||
backoff.insert(*dc, next);
|
||||
last_attempt.insert(*dc, now);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6 coverage check (if available)
|
||||
let map_v6 = pool.proxy_map_v6.read().await.clone();
|
||||
let writer_addrs_v6: std::collections::HashSet<SocketAddr> = pool
|
||||
.writers
|
||||
.read()
|
||||
.await
|
||||
.iter()
|
||||
.map(|w| w.addr)
|
||||
.collect();
|
||||
for (dc, addrs) in map_v6.iter() {
|
||||
let dc_addrs: Vec<SocketAddr> = addrs
|
||||
.iter()
|
||||
.map(|(ip, port)| SocketAddr::new(*ip, *port))
|
||||
.collect();
|
||||
let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a));
|
||||
if !has_coverage {
|
||||
let delay = *backoff.get(dc).unwrap_or(&30);
|
||||
let now = Instant::now();
|
||||
if let Some(last) = last_attempt.get(dc) {
|
||||
if now.duration_since(*last).as_secs() < delay {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting...");
|
||||
let mut shuffled = dc_addrs.clone();
|
||||
shuffled.shuffle(&mut rand::rng());
|
||||
let mut reconnected = false;
|
||||
for addr in shuffled {
|
||||
match pool.connect_one(addr, &rng).await {
|
||||
Ok(()) => {
|
||||
info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage");
|
||||
backoff.insert(*dc, 30);
|
||||
last_attempt.insert(*dc, now);
|
||||
reconnected = true;
|
||||
break;
|
||||
}
|
||||
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"),
|
||||
}
|
||||
}
|
||||
if !reconnected {
|
||||
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
|
||||
backoff.insert(*dc, next);
|
||||
last_attempt.insert(*dc, now);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ mod reader;
|
||||
mod registry;
|
||||
mod send;
|
||||
mod secret;
|
||||
mod rotation;
|
||||
mod config_updater;
|
||||
mod wire;
|
||||
|
||||
@@ -18,10 +19,11 @@ use bytes::Bytes;
|
||||
pub use health::me_health_monitor;
|
||||
pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily};
|
||||
pub use pool::MePool;
|
||||
pub use pool_nat::{stun_probe, StunProbeResult};
|
||||
pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult};
|
||||
pub use registry::ConnRegistry;
|
||||
pub use secret::fetch_proxy_secret;
|
||||
pub use config_updater::{fetch_proxy_config, me_config_updater};
|
||||
pub use rotation::me_rotation_task;
|
||||
pub use wire::proto_flags_for_tag;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicI32, AtomicU64};
|
||||
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
|
||||
use bytes::BytesMut;
|
||||
use rand::Rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info, warn};
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -14,15 +15,26 @@ use crate::error::{ProxyError, Result};
|
||||
use crate::protocol::constants::*;
|
||||
|
||||
use super::ConnRegistry;
|
||||
use super::registry::{BoundConn, ConnMeta};
|
||||
use super::codec::RpcWriter;
|
||||
use super::reader::reader_loop;
|
||||
use super::MeResponse;
|
||||
|
||||
const ME_ACTIVE_PING_SECS: u64 = 25;
|
||||
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MeWriter {
|
||||
pub id: u64,
|
||||
pub addr: SocketAddr,
|
||||
pub writer: Arc<Mutex<RpcWriter>>,
|
||||
pub cancel: CancellationToken,
|
||||
pub degraded: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
pub struct MePool {
|
||||
pub(super) registry: Arc<ConnRegistry>,
|
||||
pub(super) writers: Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> ,
|
||||
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
|
||||
pub(super) rr: AtomicU64,
|
||||
pub(super) proxy_tag: Option<Vec<u8>>,
|
||||
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
|
||||
@@ -33,6 +45,10 @@ pub struct MePool {
|
||||
pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
|
||||
pub(super) proxy_map_v6: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
|
||||
pub(super) default_dc: AtomicI32,
|
||||
pub(super) next_writer_id: AtomicU64,
|
||||
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
|
||||
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
|
||||
pub(super) nat_reflection_cache: Arc<Mutex<Option<(std::time::Instant, std::net::SocketAddr)>>>,
|
||||
pool_size: usize,
|
||||
}
|
||||
|
||||
@@ -61,6 +77,10 @@ impl MePool {
|
||||
proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)),
|
||||
proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)),
|
||||
default_dc: AtomicI32::new(default_dc.unwrap_or(0)),
|
||||
next_writer_id: AtomicU64::new(1),
|
||||
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
|
||||
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
|
||||
nat_reflection_cache: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -77,16 +97,19 @@ impl MePool {
|
||||
&self.registry
|
||||
}
|
||||
|
||||
fn writers_arc(&self) -> Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>>
|
||||
{
|
||||
fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
|
||||
self.writers.clone()
|
||||
}
|
||||
|
||||
pub async fn reconcile_connections(&self, rng: &SecureRandom) {
|
||||
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
|
||||
use std::collections::HashSet;
|
||||
let map = self.proxy_map_v4.read().await.clone();
|
||||
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
|
||||
.iter()
|
||||
.map(|(dc, addrs)| (*dc, addrs.clone()))
|
||||
.collect();
|
||||
let writers = self.writers.read().await;
|
||||
let current: HashSet<SocketAddr> = writers.iter().map(|(a, _)| *a).collect();
|
||||
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
|
||||
drop(writers);
|
||||
|
||||
for (_dc, addrs) in map.iter() {
|
||||
@@ -158,8 +181,12 @@ impl MePool {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> {
|
||||
let map = self.proxy_map_v4.read().await;
|
||||
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
|
||||
let map = self.proxy_map_v4.read().await.clone();
|
||||
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
|
||||
.iter()
|
||||
.map(|(dc, addrs)| (*dc, addrs.clone()))
|
||||
.collect();
|
||||
let ks = self.key_selector().await;
|
||||
info!(
|
||||
me_servers = map.len(),
|
||||
@@ -169,38 +196,28 @@ impl MePool {
|
||||
"Initializing ME pool"
|
||||
);
|
||||
|
||||
// Ensure at least one connection per DC with failover over all addresses
|
||||
for (dc, addrs) in map.iter() {
|
||||
// Ensure at least one connection per DC; run DCs in parallel.
|
||||
let mut join = tokio::task::JoinSet::new();
|
||||
for (dc, addrs) in dc_addrs.iter().cloned() {
|
||||
if addrs.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let mut connected = false;
|
||||
let mut shuffled = addrs.clone();
|
||||
shuffled.shuffle(&mut rand::rng());
|
||||
for (ip, port) in shuffled {
|
||||
let addr = SocketAddr::new(ip, port);
|
||||
match self.connect_one(addr, rng).await {
|
||||
Ok(()) => {
|
||||
info!(%addr, dc = %dc, "ME connected");
|
||||
connected = true;
|
||||
break;
|
||||
}
|
||||
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
|
||||
}
|
||||
}
|
||||
if !connected {
|
||||
warn!(dc = %dc, "All ME servers for DC failed at init");
|
||||
}
|
||||
let pool = Arc::clone(self);
|
||||
let rng_clone = Arc::clone(rng);
|
||||
join.spawn(async move {
|
||||
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
|
||||
});
|
||||
}
|
||||
while let Some(_res) = join.join_next().await {}
|
||||
|
||||
// Additional connections up to pool_size total (round-robin across DCs)
|
||||
for (dc, addrs) in map.iter() {
|
||||
for (dc, addrs) in dc_addrs.iter() {
|
||||
for (ip, port) in addrs {
|
||||
if self.connection_count() >= pool_size {
|
||||
break;
|
||||
}
|
||||
let addr = SocketAddr::new(*ip, *port);
|
||||
if let Err(e) = self.connect_one(addr, rng).await {
|
||||
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
|
||||
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
|
||||
}
|
||||
}
|
||||
@@ -215,7 +232,7 @@ impl MePool {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn connect_one(&self, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
|
||||
pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
|
||||
let secret_len = self.proxy_secret.read().await.len();
|
||||
if secret_len < 32 {
|
||||
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
|
||||
@@ -224,44 +241,88 @@ impl MePool {
|
||||
let (stream, _connect_ms) = self.connect_tcp(addr).await?;
|
||||
let hs = self.handshake_only(stream, addr, rng).await?;
|
||||
|
||||
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
|
||||
let cancel = CancellationToken::new();
|
||||
let degraded = Arc::new(AtomicBool::new(false));
|
||||
let rpc_w = Arc::new(Mutex::new(RpcWriter {
|
||||
writer: hs.wr,
|
||||
key: hs.write_key,
|
||||
iv: hs.write_iv,
|
||||
seq_no: 0,
|
||||
}));
|
||||
self.writers.write().await.push((addr, rpc_w.clone()));
|
||||
let writer = MeWriter {
|
||||
id: writer_id,
|
||||
addr,
|
||||
writer: rpc_w.clone(),
|
||||
cancel: cancel.clone(),
|
||||
degraded: degraded.clone(),
|
||||
};
|
||||
self.writers.write().await.push(writer.clone());
|
||||
|
||||
let reg = self.registry.clone();
|
||||
let w_pong = rpc_w.clone();
|
||||
let w_pool = self.writers_arc();
|
||||
let w_ping = rpc_w.clone();
|
||||
let w_pool_ping = self.writers_arc();
|
||||
let writers_arc = self.writers_arc();
|
||||
let ping_tracker = self.ping_tracker.clone();
|
||||
let rtt_stats = self.rtt_stats.clone();
|
||||
let pool = Arc::downgrade(self);
|
||||
let cancel_ping = cancel.clone();
|
||||
let rpc_w_ping = rpc_w.clone();
|
||||
let ping_tracker_ping = ping_tracker.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) =
|
||||
reader_loop(hs.rd, hs.read_key, hs.read_iv, reg, BytesMut::new(), BytesMut::new(), w_pong.clone()).await
|
||||
{
|
||||
let cancel_reader = cancel.clone();
|
||||
let res = reader_loop(
|
||||
hs.rd,
|
||||
hs.read_key,
|
||||
hs.read_iv,
|
||||
reg.clone(),
|
||||
BytesMut::new(),
|
||||
BytesMut::new(),
|
||||
rpc_w.clone(),
|
||||
ping_tracker.clone(),
|
||||
rtt_stats.clone(),
|
||||
writer_id,
|
||||
degraded.clone(),
|
||||
cancel_reader.clone(),
|
||||
)
|
||||
.await;
|
||||
if let Some(pool) = pool.upgrade() {
|
||||
pool.remove_writer_and_reroute(writer_id).await;
|
||||
}
|
||||
if let Err(e) = res {
|
||||
warn!(error = %e, "ME reader ended");
|
||||
}
|
||||
let mut ws = w_pool.write().await;
|
||||
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong));
|
||||
let mut ws = writers_arc.write().await;
|
||||
ws.retain(|w| w.id != writer_id);
|
||||
info!(remaining = ws.len(), "Dead ME writer removed from pool");
|
||||
});
|
||||
|
||||
let pool_ping = Arc::downgrade(self);
|
||||
tokio::spawn(async move {
|
||||
let mut ping_id: i64 = rand::random::<i64>();
|
||||
loop {
|
||||
let jitter = rand::rng()
|
||||
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
|
||||
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
|
||||
tokio::time::sleep(Duration::from_secs(wait)).await;
|
||||
tokio::select! {
|
||||
_ = cancel_ping.cancelled() => {
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
|
||||
}
|
||||
let mut p = Vec::with_capacity(12);
|
||||
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
|
||||
p.extend_from_slice(&ping_id.to_le_bytes());
|
||||
ping_id = ping_id.wrapping_add(1);
|
||||
if let Err(e) = w_ping.lock().await.send(&p).await {
|
||||
{
|
||||
let mut tracker = ping_tracker_ping.lock().await;
|
||||
tracker.insert(ping_id, (std::time::Instant::now(), writer_id));
|
||||
}
|
||||
if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
|
||||
debug!(error = %e, "Active ME ping failed, removing dead writer");
|
||||
let mut ws = w_pool_ping.write().await;
|
||||
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping));
|
||||
cancel_ping.cancel();
|
||||
if let Some(pool) = pool_ping.upgrade() {
|
||||
pool.remove_writer_and_reroute(writer_id).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -270,6 +331,124 @@ impl MePool {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn connect_primary_for_dc(
|
||||
self: Arc<Self>,
|
||||
dc: i32,
|
||||
mut addrs: Vec<(IpAddr, u16)>,
|
||||
rng: Arc<SecureRandom>,
|
||||
) {
|
||||
if addrs.is_empty() {
|
||||
return;
|
||||
}
|
||||
addrs.shuffle(&mut rand::rng());
|
||||
for (ip, port) in addrs {
|
||||
let addr = SocketAddr::new(ip, port);
|
||||
match self.connect_one(addr, rng.as_ref()).await {
|
||||
Ok(()) => {
|
||||
info!(%addr, dc = %dc, "ME connected");
|
||||
return;
|
||||
}
|
||||
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
|
||||
}
|
||||
}
|
||||
warn!(dc = %dc, "All ME servers for DC failed at init");
|
||||
}
|
||||
|
||||
pub(crate) async fn remove_writer_and_reroute(&self, writer_id: u64) {
|
||||
let mut queue = self.remove_writer_only(writer_id).await;
|
||||
while let Some(bound) = queue.pop() {
|
||||
if !self.reroute_conn(&bound, &mut queue).await {
|
||||
let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> {
|
||||
{
|
||||
let mut ws = self.writers.write().await;
|
||||
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
|
||||
let w = ws.remove(pos);
|
||||
w.cancel.cancel();
|
||||
}
|
||||
}
|
||||
self.registry.writer_lost(writer_id).await
|
||||
}
|
||||
|
||||
async fn reroute_conn(&self, bound: &BoundConn, backlog: &mut Vec<BoundConn>) -> bool {
|
||||
let payload = super::wire::build_proxy_req_payload(
|
||||
bound.conn_id,
|
||||
bound.meta.client_addr,
|
||||
bound.meta.our_addr,
|
||||
&[],
|
||||
self.proxy_tag.as_deref(),
|
||||
bound.meta.proto_flags,
|
||||
);
|
||||
|
||||
let mut attempts = 0;
|
||||
loop {
|
||||
let writers_snapshot = {
|
||||
let ws = self.writers.read().await;
|
||||
if ws.is_empty() {
|
||||
return false;
|
||||
}
|
||||
ws.clone()
|
||||
};
|
||||
let mut candidates = self.candidate_indices_for_dc(&writers_snapshot, bound.meta.target_dc).await;
|
||||
if candidates.is_empty() {
|
||||
return false;
|
||||
}
|
||||
candidates.sort_by_key(|idx| {
|
||||
writers_snapshot[*idx]
|
||||
.degraded
|
||||
.load(Ordering::Relaxed)
|
||||
.then_some(1usize)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidates.len();
|
||||
|
||||
for offset in 0..candidates.len() {
|
||||
let idx = candidates[(start + offset) % candidates.len()];
|
||||
let w = &writers_snapshot[idx];
|
||||
if let Ok(mut guard) = w.writer.try_lock() {
|
||||
let send_res = guard.send(&payload).await;
|
||||
drop(guard);
|
||||
match send_res {
|
||||
Ok(()) => {
|
||||
self.registry
|
||||
.bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
|
||||
.await;
|
||||
return true;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, writer_id = w.id, "ME reroute send failed");
|
||||
backlog.extend(self.remove_writer_only(w.id).await);
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let w = writers_snapshot[candidates[start]].clone();
|
||||
match w.writer.lock().await.send(&payload).await {
|
||||
Ok(()) => {
|
||||
self.registry
|
||||
.bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
|
||||
.await;
|
||||
return true;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, writer_id = w.id, "ME reroute send failed (blocking)");
|
||||
backlog.extend(self.remove_writer_only(w.id).await);
|
||||
}
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
if attempts > 3 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn hex_dump(data: &[u8]) -> String {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::Duration;
|
||||
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
use super::MePool;
|
||||
use std::time::Instant;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct StunProbeResult {
|
||||
@@ -17,6 +19,10 @@ pub async fn stun_probe(stun_addr: Option<String>) -> Result<Option<StunProbeRes
|
||||
fetch_stun_binding(&stun_addr).await
|
||||
}
|
||||
|
||||
pub async fn detect_public_ip() -> Option<IpAddr> {
|
||||
fetch_public_ipv4_with_retry().await.ok().flatten().map(IpAddr::V4)
|
||||
}
|
||||
|
||||
impl MePool {
|
||||
pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
|
||||
let nat_ip = self
|
||||
@@ -93,6 +99,15 @@ impl MePool {
|
||||
}
|
||||
|
||||
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> {
|
||||
const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
|
||||
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
|
||||
if let Some((ts, addr)) = *cache {
|
||||
if ts.elapsed() < STUN_CACHE_TTL {
|
||||
return Some(addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let stun_addr = self
|
||||
.nat_stun
|
||||
.clone()
|
||||
@@ -101,6 +116,9 @@ impl MePool {
|
||||
Ok(sa) => {
|
||||
if let Some(result) = sa {
|
||||
info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address");
|
||||
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
|
||||
*cache = Some((Instant::now(), result.reflected_addr));
|
||||
}
|
||||
Some(result.reflected_addr)
|
||||
} else {
|
||||
None
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Instant;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use crate::crypto::{AesCbc, crc32};
|
||||
@@ -21,12 +25,21 @@ pub(crate) async fn reader_loop(
|
||||
enc_leftover: BytesMut,
|
||||
mut dec: BytesMut,
|
||||
writer: Arc<Mutex<RpcWriter>>,
|
||||
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
|
||||
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
|
||||
_writer_id: u64,
|
||||
degraded: Arc<AtomicBool>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<()> {
|
||||
let mut raw = enc_leftover;
|
||||
let mut expected_seq: i32 = 0;
|
||||
|
||||
loop {
|
||||
let mut tmp = [0u8; 16_384];
|
||||
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
|
||||
let n = tokio::select! {
|
||||
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
|
||||
_ = cancel.cancelled() => return Ok(()),
|
||||
};
|
||||
if n == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -70,6 +83,14 @@ pub(crate) async fn reader_loop(
|
||||
continue;
|
||||
}
|
||||
|
||||
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
|
||||
if seq_no != expected_seq {
|
||||
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
|
||||
expected_seq = seq_no.wrapping_add(1);
|
||||
} else {
|
||||
expected_seq = expected_seq.wrapping_add(1);
|
||||
}
|
||||
|
||||
let payload = &frame[8..pe];
|
||||
if payload.len() < 4 {
|
||||
continue;
|
||||
@@ -119,6 +140,23 @@ pub(crate) async fn reader_loop(
|
||||
warn!(error = %e, "PONG send failed");
|
||||
break;
|
||||
}
|
||||
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
|
||||
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
||||
if let Some((sent, wid)) = {
|
||||
let mut guard = ping_tracker.lock().await;
|
||||
guard.remove(&ping_id)
|
||||
} {
|
||||
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
|
||||
let mut stats = rtt_stats.lock().await;
|
||||
let entry = stats.entry(wid).or_insert((rtt, rtt));
|
||||
entry.1 = entry.1 * 0.8 + rtt * 0.2;
|
||||
if rtt < entry.0 {
|
||||
entry.0 = rtt;
|
||||
}
|
||||
let degraded_now = entry.1 > entry.0 * 2.0;
|
||||
degraded.store(degraded_now, Ordering::Relaxed);
|
||||
trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample");
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
rpc_type = format_args!("0x{pt:08x}"),
|
||||
|
||||
@@ -1,60 +1,133 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use tokio::sync::{RwLock, mpsc};
|
||||
|
||||
use super::MeResponse;
|
||||
use super::codec::RpcWriter;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use tokio::sync::{mpsc, Mutex, RwLock};
|
||||
|
||||
use super::codec::RpcWriter;
|
||||
use super::MeResponse;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnMeta {
|
||||
pub target_dc: i16,
|
||||
pub client_addr: SocketAddr,
|
||||
pub our_addr: SocketAddr,
|
||||
pub proto_flags: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct BoundConn {
|
||||
pub conn_id: u64,
|
||||
pub meta: ConnMeta,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ConnWriter {
|
||||
pub writer_id: u64,
|
||||
pub writer: Arc<Mutex<RpcWriter>>,
|
||||
}
|
||||
|
||||
pub struct ConnRegistry {
|
||||
map: RwLock<HashMap<u64, mpsc::UnboundedSender<MeResponse>>>,
|
||||
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
|
||||
writers: RwLock<HashMap<u64, Arc<Mutex<RpcWriter>>>>,
|
||||
writer_for_conn: RwLock<HashMap<u64, u64>>,
|
||||
conns_for_writer: RwLock<HashMap<u64, Vec<u64>>>,
|
||||
meta: RwLock<HashMap<u64, ConnMeta>>,
|
||||
next_id: AtomicU64,
|
||||
}
|
||||
|
||||
impl ConnRegistry {
|
||||
pub fn new() -> Self {
|
||||
// Avoid fully predictable conn_id sequence from 1.
|
||||
let start = rand::random::<u64>() | 1;
|
||||
Self {
|
||||
map: RwLock::new(HashMap::new()),
|
||||
writers: RwLock::new(HashMap::new()),
|
||||
writer_for_conn: RwLock::new(HashMap::new()),
|
||||
conns_for_writer: RwLock::new(HashMap::new()),
|
||||
meta: RwLock::new(HashMap::new()),
|
||||
next_id: AtomicU64::new(start),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn register(&self) -> (u64, mpsc::UnboundedReceiver<MeResponse>) {
|
||||
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
// Unbounded per-connection queue prevents reader-loop HOL blocking on
|
||||
// slow clients: routing stays non-blocking and preserves message order.
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let (tx, rx) = mpsc::channel(1024);
|
||||
self.map.write().await.insert(id, tx);
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
pub async fn unregister(&self, id: u64) {
|
||||
self.map.write().await.remove(&id);
|
||||
self.writers.write().await.remove(&id);
|
||||
self.meta.write().await.remove(&id);
|
||||
if let Some(writer_id) = self.writer_for_conn.write().await.remove(&id) {
|
||||
if let Some(list) = self.conns_for_writer.write().await.get_mut(&writer_id) {
|
||||
list.retain(|c| *c != id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
|
||||
let m = self.map.read().await;
|
||||
if let Some(tx) = m.get(&id) {
|
||||
tx.send(resp).is_ok()
|
||||
tx.try_send(resp).is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_writer(&self, id: u64, w: Arc<Mutex<RpcWriter>>) {
|
||||
let mut guard = self.writers.write().await;
|
||||
guard.entry(id).or_insert_with(|| w);
|
||||
pub async fn bind_writer(
|
||||
&self,
|
||||
conn_id: u64,
|
||||
writer_id: u64,
|
||||
writer: Arc<Mutex<RpcWriter>>,
|
||||
meta: ConnMeta,
|
||||
) {
|
||||
self.meta.write().await.entry(conn_id).or_insert(meta);
|
||||
self.writer_for_conn.write().await.insert(conn_id, writer_id);
|
||||
self.writers.write().await.entry(writer_id).or_insert_with(|| writer.clone());
|
||||
self.conns_for_writer
|
||||
.write()
|
||||
.await
|
||||
.entry(writer_id)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(conn_id);
|
||||
}
|
||||
|
||||
pub async fn get_writer(&self, id: u64) -> Option<Arc<Mutex<RpcWriter>>> {
|
||||
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
|
||||
let writer_id = {
|
||||
let guard = self.writer_for_conn.read().await;
|
||||
guard.get(&conn_id).cloned()
|
||||
}?;
|
||||
let writer = {
|
||||
let guard = self.writers.read().await;
|
||||
guard.get(&id).cloned()
|
||||
guard.get(&writer_id).cloned()
|
||||
}?;
|
||||
Some(ConnWriter { writer_id, writer })
|
||||
}
|
||||
|
||||
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
|
||||
self.writers.write().await.remove(&writer_id);
|
||||
let conns = self.conns_for_writer.write().await.remove(&writer_id).unwrap_or_default();
|
||||
|
||||
let mut out = Vec::new();
|
||||
let mut writer_for_conn = self.writer_for_conn.write().await;
|
||||
let meta = self.meta.read().await;
|
||||
|
||||
for conn_id in conns {
|
||||
writer_for_conn.remove(&conn_id);
|
||||
if let Some(m) = meta.get(&conn_id) {
|
||||
out.push(BoundConn {
|
||||
conn_id,
|
||||
meta: m.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
|
||||
let guard = self.meta.read().await;
|
||||
guard.get(&conn_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
37
src/transport/middle_proxy/rotation.rs
Normal file
37
src/transport/middle_proxy/rotation.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
use super::MePool;
|
||||
|
||||
/// Periodically refresh ME connections to avoid long-lived degradation.
|
||||
pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interval: Duration) {
|
||||
let interval = interval.max(Duration::from_secs(600));
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
let candidate = {
|
||||
let ws = pool.writers.read().await;
|
||||
ws.get(0).cloned()
|
||||
};
|
||||
|
||||
let Some(w) = candidate else {
|
||||
continue;
|
||||
};
|
||||
|
||||
info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection");
|
||||
match pool.connect_one(w.addr, rng.as_ref()).await {
|
||||
Ok(()) => {
|
||||
// Remove old writer after new one is up.
|
||||
pool.remove_writer_and_reroute(w.id).await;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use tracing::{debug, info, warn};
|
||||
use std::time::SystemTime;
|
||||
use httpdate;
|
||||
|
||||
use crate::error::{ProxyError, Result};
|
||||
|
||||
@@ -63,6 +65,23 @@ pub async fn download_proxy_secret() -> Result<Vec<u8>> {
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(date) = resp.headers().get(reqwest::header::DATE) {
|
||||
if let Ok(date_str) = date.to_str() {
|
||||
if let Ok(server_time) = httpdate::parse_http_date(date_str) {
|
||||
if let Ok(skew) = SystemTime::now().duration_since(server_time).or_else(|e| {
|
||||
server_time.duration_since(SystemTime::now()).map_err(|_| e)
|
||||
}) {
|
||||
let skew_secs = skew.as_secs();
|
||||
if skew_secs > 60 {
|
||||
warn!(skew_secs, "Time skew >60s detected from proxy-secret Date header");
|
||||
} else if skew_secs > 30 {
|
||||
warn!(skew_secs, "Time skew >30s detected from proxy-secret Date header");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let data = resp
|
||||
.bytes()
|
||||
.await
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, warn};
|
||||
@@ -9,14 +10,14 @@ use crate::error::{ProxyError, Result};
|
||||
use crate::protocol::constants::RPC_CLOSE_EXT_U32;
|
||||
|
||||
use super::MePool;
|
||||
use super::codec::RpcWriter;
|
||||
use super::wire::build_proxy_req_payload;
|
||||
use crate::crypto::SecureRandom;
|
||||
use rand::seq::SliceRandom;
|
||||
use super::registry::ConnMeta;
|
||||
|
||||
impl MePool {
|
||||
pub async fn send_proxy_req(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
conn_id: u64,
|
||||
target_dc: i16,
|
||||
client_addr: SocketAddr,
|
||||
@@ -32,18 +33,50 @@ impl MePool {
|
||||
self.proxy_tag.as_deref(),
|
||||
proto_flags,
|
||||
);
|
||||
let meta = ConnMeta {
|
||||
target_dc,
|
||||
client_addr,
|
||||
our_addr,
|
||||
proto_flags,
|
||||
};
|
||||
let mut emergency_attempts = 0;
|
||||
|
||||
loop {
|
||||
if let Some(current) = self.registry.get_writer(conn_id).await {
|
||||
let send_res = {
|
||||
if let Ok(mut guard) = current.writer.try_lock() {
|
||||
let r = guard.send(&payload).await;
|
||||
drop(guard);
|
||||
r
|
||||
} else {
|
||||
current.writer.lock().await.send(&payload).await
|
||||
}
|
||||
};
|
||||
match send_res {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
warn!(error = %e, writer_id = current.writer_id, "ME write failed");
|
||||
self.remove_writer_and_reroute(current.writer_id).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut writers_snapshot = {
|
||||
let ws = self.writers.read().await;
|
||||
if ws.is_empty() {
|
||||
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||
}
|
||||
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws.iter().cloned().collect();
|
||||
drop(ws);
|
||||
ws.clone()
|
||||
};
|
||||
|
||||
let mut candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await;
|
||||
let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
|
||||
if candidate_indices.is_empty() {
|
||||
// Emergency: try to connect to target DC addresses on the fly, then recompute writers
|
||||
// Emergency connect-on-demand
|
||||
if emergency_attempts >= 3 {
|
||||
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
|
||||
}
|
||||
emergency_attempts += 1;
|
||||
let map = self.proxy_map_v4.read().await;
|
||||
if let Some(addrs) = map.get(&(target_dc as i32)) {
|
||||
let mut shuffled = addrs.clone();
|
||||
@@ -55,65 +88,73 @@ impl MePool {
|
||||
break;
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await;
|
||||
let ws2 = self.writers.read().await;
|
||||
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws2.iter().cloned().collect();
|
||||
writers_snapshot = ws2.clone();
|
||||
drop(ws2);
|
||||
candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await;
|
||||
candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
|
||||
}
|
||||
if candidate_indices.is_empty() {
|
||||
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
|
||||
}
|
||||
}
|
||||
|
||||
candidate_indices.sort_by_key(|idx| {
|
||||
writers_snapshot[*idx]
|
||||
.degraded
|
||||
.load(Ordering::Relaxed)
|
||||
.then_some(1usize)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
|
||||
|
||||
// Prefer immediately available writer to avoid waiting on stalled connection.
|
||||
for offset in 0..candidate_indices.len() {
|
||||
let cidx = (start + offset) % candidate_indices.len();
|
||||
let idx = candidate_indices[cidx];
|
||||
let w = writers[idx].1.clone();
|
||||
if let Ok(mut guard) = w.try_lock() {
|
||||
let idx = candidate_indices[(start + offset) % candidate_indices.len()];
|
||||
let w = &writers_snapshot[idx];
|
||||
if let Ok(mut guard) = w.writer.try_lock() {
|
||||
let send_res = guard.send(&payload).await;
|
||||
drop(guard);
|
||||
match send_res {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "ME write failed, removing dead conn");
|
||||
let mut ws = self.writers.write().await;
|
||||
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||
if ws.is_empty() {
|
||||
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||
Ok(()) => {
|
||||
self.registry
|
||||
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, writer_id = w.id, "ME write failed");
|
||||
self.remove_writer_and_reroute(w.id).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All writers are currently busy, wait for the selected one.
|
||||
let w = writers[candidate_indices[start]].1.clone();
|
||||
match w.lock().await.send(&payload).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "ME write failed, removing dead conn");
|
||||
let mut ws = self.writers.write().await;
|
||||
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||
if ws.is_empty() {
|
||||
return Err(ProxyError::Proxy("All ME connections dead".into()));
|
||||
let w = writers_snapshot[candidate_indices[start]].clone();
|
||||
match w.writer.lock().await.send(&payload).await {
|
||||
Ok(()) => {
|
||||
self.registry
|
||||
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, writer_id = w.id, "ME write failed (blocking)");
|
||||
self.remove_writer_and_reroute(w.id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
|
||||
pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
|
||||
if let Some(w) = self.registry.get_writer(conn_id).await {
|
||||
let mut p = Vec::with_capacity(12);
|
||||
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
|
||||
p.extend_from_slice(&conn_id.to_le_bytes());
|
||||
if let Err(e) = w.lock().await.send(&p).await {
|
||||
if let Err(e) = w.writer.lock().await.send(&p).await {
|
||||
debug!(error = %e, "ME close write failed");
|
||||
let mut ws = self.writers.write().await;
|
||||
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
|
||||
self.remove_writer_and_reroute(w.writer_id).await;
|
||||
}
|
||||
} else {
|
||||
debug!(conn_id, "ME close skipped (writer missing)");
|
||||
@@ -129,7 +170,7 @@ impl MePool {
|
||||
|
||||
pub(super) async fn candidate_indices_for_dc(
|
||||
&self,
|
||||
writers: &[(SocketAddr, Arc<Mutex<RpcWriter>>)],
|
||||
writers: &[super::pool::MeWriter],
|
||||
target_dc: i16,
|
||||
) -> Vec<usize> {
|
||||
let mut preferred = Vec::<SocketAddr>::new();
|
||||
@@ -165,8 +206,8 @@ impl MePool {
|
||||
}
|
||||
|
||||
let mut out = Vec::new();
|
||||
for (idx, (addr, _)) in writers.iter().enumerate() {
|
||||
if preferred.iter().any(|p| p == addr) {
|
||||
for (idx, w) in writers.iter().enumerate() {
|
||||
if preferred.iter().any(|p| *p == w.addr) {
|
||||
out.push(idx);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user