Fixed critical ME Problems

This commit is contained in:
Alexey
2026-02-17 03:40:39 +03:00
parent 8bd02d8099
commit 168fd59187
7 changed files with 459 additions and 103 deletions

View File

@@ -9,7 +9,7 @@ libc = "0.2"
# Async runtime
tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] }
tokio-util = { version = "0.7", features = ["full"] }
# Crypto
aes = "0.8"

View File

@@ -1,5 +1,12 @@
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
use socket2::{SockRef, TcpKeepalive};
#[cfg(target_os = "linux")]
use libc;
#[cfg(target_os = "linux")]
use std::os::fd::{AsRawFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::raw::c_int;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
@@ -41,9 +48,45 @@ impl MePool {
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??;
let connect_ms = start.elapsed().as_secs_f64() * 1000.0;
stream.set_nodelay(true).ok();
if let Err(e) = Self::configure_keepalive(&stream) {
warn!(error = %e, "ME keepalive setup failed");
}
#[cfg(target_os = "linux")]
if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) {
warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed");
}
Ok((stream, connect_ms))
}
/// Turns on TCP keepalive probing for an ME connection.
///
/// The socket is configured with a 30 s keepalive time, a 10 s probe
/// interval and 3 retries, after which `SO_KEEPALIVE` is enabled explicitly.
///
/// # Errors
/// Returns the I/O error of the first socket option call that fails.
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
    let idle = Duration::from_secs(30);
    let probe_gap = Duration::from_secs(10);
    let params = TcpKeepalive::new()
        .with_time(idle)
        .with_interval(probe_gap)
        .with_retries(3);

    // Borrow the raw socket without taking ownership of the stream.
    let socket = SockRef::from(stream);
    socket.set_tcp_keepalive(&params)?;
    socket.set_keepalive(true)
}
/// Sets `TCP_USER_TIMEOUT` to 30 000 ms on the socket behind `fd`.
///
/// # Errors
/// Returns the last OS error when `setsockopt` reports a non-zero status.
#[cfg(target_os = "linux")]
fn configure_user_timeout(fd: RawFd) -> std::io::Result<()> {
    let millis: c_int = 30_000;
    let value_ptr = &millis as *const _ as *const libc::c_void;
    let value_len = std::mem::size_of_val(&millis) as libc::socklen_t;

    // SAFETY: `value_ptr` points at `millis`, which stays alive for the whole
    // call, and `value_len` is exactly the size of that value, so the kernel
    // never reads past it. `fd` is passed through as a plain integer.
    let status = unsafe {
        libc::setsockopt(
            fd,
            libc::IPPROTO_TCP,
            libc::TCP_USER_TIMEOUT,
            value_ptr,
            value_len,
        )
    };

    if status == 0 {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
/// Perform full ME RPC handshake on an established TCP stream.
/// Returns cipher keys/ivs and split halves; does not register writer.
pub(crate) async fn handshake_only(

View File

@@ -19,7 +19,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
.read()
.await
.iter()
.map(|(a, _)| *a)
.map(|w| w.addr)
.collect();
for (dc, addrs) in map.iter() {

View File

@@ -1,11 +1,12 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicI32, AtomicU64};
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
use bytes::BytesMut;
use rand::Rng;
use rand::seq::SliceRandom;
use tokio::sync::{Mutex, RwLock};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use std::time::Duration;
@@ -14,15 +15,26 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta};
use super::codec::RpcWriter;
use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
/// Handle to one ME connection held in the pool's writer list.
///
/// Cloning is cheap: the writer and the degraded flag are `Arc`-wrapped and
/// shared between clones.
#[derive(Clone)]
pub struct MeWriter {
/// Identifier assigned from `MePool::next_writer_id` when the connection is set up.
pub id: u64,
/// Address the underlying TCP stream was connected to.
pub addr: SocketAddr,
/// Shared RPC writer; senders lock it to send a payload.
pub writer: Arc<Mutex<RpcWriter>>,
/// Cancelled when the writer is removed from the pool; the ping task and the
/// reader loop both watch it and stop.
pub cancel: CancellationToken,
/// Set by the reader loop from RTT samples (moving average above twice the
/// best observed RTT). Writer selection sorts degraded writers after healthy ones.
pub degraded: Arc<AtomicBool>,
}
pub struct MePool {
pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> ,
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64,
pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
@@ -33,6 +45,9 @@ pub struct MePool {
pub(super) proxy_map_v4: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) proxy_map_v6: Arc<RwLock<HashMap<i32, Vec<(IpAddr, u16)>>>>,
pub(super) default_dc: AtomicI32,
pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pool_size: usize,
}
@@ -61,6 +76,9 @@ impl MePool {
proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)),
proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)),
default_dc: AtomicI32::new(default_dc.unwrap_or(0)),
next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
})
}
@@ -77,16 +95,19 @@ impl MePool {
&self.registry
}
fn writers_arc(&self) -> Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>>
{
fn writers_arc(&self) -> Arc<RwLock<Vec<MeWriter>>> {
self.writers.clone()
}
pub async fn reconcile_connections(&self, rng: &SecureRandom) {
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet;
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|(a, _)| *a).collect();
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers);
for (_dc, addrs) in map.iter() {
@@ -158,8 +179,12 @@ impl MePool {
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> {
let map = self.proxy_map_v4.read().await;
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let ks = self.key_selector().await;
info!(
me_servers = map.len(),
@@ -169,38 +194,28 @@ impl MePool {
"Initializing ME pool"
);
// Ensure at least one connection per DC with failover over all addresses
for (dc, addrs) in map.iter() {
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let mut connected = false;
let mut shuffled = addrs.clone();
shuffled.shuffle(&mut rand::rng());
for (ip, port) in shuffled {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME connected");
connected = true;
break;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
}
}
if !connected {
warn!(dc = %dc, "All ME servers for DC failed at init");
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in map.iter() {
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng).await {
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
@@ -215,7 +230,7 @@ impl MePool {
Ok(())
}
pub(crate) async fn connect_one(&self, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
let secret_len = self.proxy_secret.read().await.len();
if secret_len < 32 {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
@@ -224,44 +239,88 @@ impl MePool {
let (stream, _connect_ms) = self.connect_tcp(addr).await?;
let hs = self.handshake_only(stream, addr, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
let cancel = CancellationToken::new();
let degraded = Arc::new(AtomicBool::new(false));
let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: hs.wr,
key: hs.write_key,
iv: hs.write_iv,
seq_no: 0,
}));
self.writers.write().await.push((addr, rpc_w.clone()));
let writer = MeWriter {
id: writer_id,
addr,
writer: rpc_w.clone(),
cancel: cancel.clone(),
degraded: degraded.clone(),
};
self.writers.write().await.push(writer.clone());
let reg = self.registry.clone();
let w_pong = rpc_w.clone();
let w_pool = self.writers_arc();
let w_ping = rpc_w.clone();
let w_pool_ping = self.writers_arc();
let writers_arc = self.writers_arc();
let ping_tracker = self.ping_tracker.clone();
let rtt_stats = self.rtt_stats.clone();
let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone();
let rpc_w_ping = rpc_w.clone();
let ping_tracker_ping = ping_tracker.clone();
tokio::spawn(async move {
if let Err(e) =
reader_loop(hs.rd, hs.read_key, hs.read_iv, reg, BytesMut::new(), BytesMut::new(), w_pong.clone()).await
{
let cancel_reader = cancel.clone();
let res = reader_loop(
hs.rd,
hs.read_key,
hs.read_iv,
reg.clone(),
BytesMut::new(),
BytesMut::new(),
rpc_w.clone(),
ping_tracker.clone(),
rtt_stats.clone(),
writer_id,
degraded.clone(),
cancel_reader.clone(),
)
.await;
if let Some(pool) = pool.upgrade() {
pool.remove_writer_and_reroute(writer_id).await;
}
if let Err(e) = res {
warn!(error = %e, "ME reader ended");
}
let mut ws = w_pool.write().await;
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong));
let mut ws = writers_arc.write().await;
ws.retain(|w| w.id != writer_id);
info!(remaining = ws.len(), "Dead ME writer removed from pool");
});
let pool_ping = Arc::downgrade(self);
tokio::spawn(async move {
let mut ping_id: i64 = rand::random::<i64>();
loop {
let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
tokio::time::sleep(Duration::from_secs(wait)).await;
tokio::select! {
_ = cancel_ping.cancelled() => {
break;
}
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
}
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1);
if let Err(e) = w_ping.lock().await.send(&p).await {
{
let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(ping_id, (std::time::Instant::now(), writer_id));
}
if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
let mut ws = w_pool_ping.write().await;
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping));
cancel_ping.cancel();
if let Some(pool) = pool_ping.upgrade() {
pool.remove_writer_and_reroute(writer_id).await;
}
break;
}
}
@@ -270,6 +329,124 @@ impl MePool {
Ok(())
}
async fn connect_primary_for_dc(
self: Arc<Self>,
dc: i32,
mut addrs: Vec<(IpAddr, u16)>,
rng: Arc<SecureRandom>,
) {
if addrs.is_empty() {
return;
}
addrs.shuffle(&mut rand::rng());
for (ip, port) in addrs {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng.as_ref()).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME connected");
return;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
}
}
warn!(dc = %dc, "All ME servers for DC failed at init");
}
/// Removes the writer `writer_id` and moves its connections elsewhere.
///
/// Connections that were bound to the removed writer are processed as a work
/// list. `reroute_conn` may append further connections to that list when it
/// removes additional writers. A connection that cannot be rerouted gets a
/// `MeResponse::Close` routed to it.
pub(crate) async fn remove_writer_and_reroute(&self, writer_id: u64) {
    let mut pending = self.remove_writer_only(writer_id).await;
    while let Some(conn) = pending.pop() {
        let rerouted = self.reroute_conn(&conn, &mut pending).await;
        if rerouted {
            continue;
        }
        // No replacement writer could take this connection.
        let _ = self.registry.route(conn.conn_id, super::MeResponse::Close).await;
    }
}
/// Drops the writer `writer_id` from the pool without rerouting anything.
///
/// If the writer is still in the pool it is removed and its cancellation
/// token is triggered. The registry is then told the writer is lost, and the
/// connections it reports as having been bound to that writer are returned.
async fn remove_writer_only(&self, writer_id: u64) -> Vec<BoundConn> {
    {
        let mut pool_writers = self.writers.write().await;
        let found = pool_writers.iter().position(|w| w.id == writer_id);
        if let Some(removed) = found.map(|pos| pool_writers.remove(pos)) {
            removed.cancel.cancel();
        }
    } // pool write lock released before the registry is touched

    self.registry.writer_lost(writer_id).await
}
/// Tries to move one connection onto another live writer for its target DC.
///
/// A proxy request payload (with empty data) is rebuilt from the stored
/// connection metadata and sent to a candidate writer. Candidates are ordered
/// with non-degraded writers first and walked starting at a round-robin offset.
///
/// Each round has two passes:
/// 1. Non-blocking: every candidate whose lock is free right now is tried.
/// 2. Blocking: if none succeeded, wait for the first candidate in rotation
///    order that did not already fail in pass 1.
///
/// A writer whose send fails is removed from the pool, and the connections that
/// were bound to it are appended to `backlog` for the caller to reroute.
///
/// Returns `true` once the connection is bound to a new writer. Returns `false`
/// when the pool is empty, no candidate exists for the DC, or more than three
/// rounds ended without a successful send.
async fn reroute_conn(&self, bound: &BoundConn, backlog: &mut Vec<BoundConn>) -> bool {
    let payload = super::wire::build_proxy_req_payload(
        bound.conn_id,
        bound.meta.client_addr,
        bound.meta.our_addr,
        &[],
        self.proxy_tag.as_deref(),
        bound.meta.proto_flags,
    );

    let mut attempts = 0;
    loop {
        // Work on a snapshot so the pool lock is not held while sending.
        let writers_snapshot = {
            let ws = self.writers.read().await;
            if ws.is_empty() {
                return false;
            }
            ws.clone()
        };
        let mut candidates = self
            .candidate_indices_for_dc(&writers_snapshot, bound.meta.target_dc)
            .await;
        if candidates.is_empty() {
            return false;
        }
        // `false < true`, so healthy writers sort ahead of degraded ones.
        candidates.sort_by_key(|idx| writers_snapshot[*idx].degraded.load(Ordering::Relaxed));

        let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidates.len();
        let rotation: Vec<usize> = (0..candidates.len())
            .map(|offset| candidates[(start + offset) % candidates.len()])
            .collect();

        // Writers that failed during this round; the blocking pass must skip them.
        let mut dead: Vec<u64> = Vec::new();

        // Pass 1: only writers that are immediately lockable.
        for &idx in &rotation {
            let w = &writers_snapshot[idx];
            if let Ok(mut guard) = w.writer.try_lock() {
                let send_res = guard.send(&payload).await;
                drop(guard);
                match send_res {
                    Ok(()) => {
                        self.registry
                            .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
                            .await;
                        return true;
                    }
                    Err(e) => {
                        warn!(error = %e, writer_id = w.id, "ME reroute send failed");
                        dead.push(w.id);
                        backlog.extend(self.remove_writer_only(w.id).await);
                    }
                }
            }
        }

        // Pass 2: wait for the first writer in rotation order that is not known dead.
        let fallback = rotation
            .iter()
            .copied()
            .find(|idx| !dead.contains(&writers_snapshot[*idx].id));
        if let Some(idx) = fallback {
            let w = writers_snapshot[idx].clone();
            // Bind the result first so the writer lock is released before the
            // registry and the pool are touched.
            let send_res = w.writer.lock().await.send(&payload).await;
            match send_res {
                Ok(()) => {
                    self.registry
                        .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone())
                        .await;
                    return true;
                }
                Err(e) => {
                    warn!(error = %e, writer_id = w.id, "ME reroute send failed (blocking)");
                    backlog.extend(self.remove_writer_only(w.id).await);
                }
            }
        }

        attempts += 1;
        if attempts > 3 {
            return false;
        }
    }
}
}
fn hex_dump(data: &[u8]) -> String {

View File

@@ -1,9 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32};
@@ -21,12 +25,20 @@ pub(crate) async fn reader_loop(
enc_leftover: BytesMut,
mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
_writer_id: u64,
degraded: Arc<AtomicBool>,
cancel: CancellationToken,
) -> Result<()> {
let mut raw = enc_leftover;
loop {
let mut tmp = [0u8; 16_384];
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
let n = tokio::select! {
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
_ = cancel.cancelled() => return Ok(()),
};
if n == 0 {
return Ok(());
}
@@ -119,6 +131,23 @@ pub(crate) async fn reader_loop(
warn!(error = %e, "PONG send failed");
break;
}
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id)
} {
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
let mut stats = rtt_stats.lock().await;
let entry = stats.entry(wid).or_insert((rtt, rtt));
entry.1 = entry.1 * 0.8 + rtt * 0.2;
if rtt < entry.0 {
entry.0 = rtt;
}
let degraded_now = entry.1 > entry.0 * 2.0;
degraded.store(degraded_now, Ordering::Relaxed);
trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample");
}
} else {
debug!(
rpc_type = format_args!("0x{pt:08x}"),

View File

@@ -1,34 +1,57 @@
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{RwLock, mpsc};
use super::MeResponse;
use super::codec::RpcWriter;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::{mpsc, Mutex, RwLock};
use super::codec::RpcWriter;
use super::MeResponse;
/// Per-connection parameters of a proxy request.
///
/// Stored in the registry when a connection is bound to a writer, so the proxy
/// request payload can be rebuilt when the connection is rerouted.
#[derive(Clone)]
pub struct ConnMeta {
/// DC the connection targets; used to select candidate writers.
pub target_dc: i16,
/// Client address passed into the proxy request payload.
pub client_addr: SocketAddr,
/// Local ("our") address passed into the proxy request payload.
pub our_addr: SocketAddr,
/// Protocol flags passed into the proxy request payload.
pub proto_flags: u32,
}
/// A connection id together with its stored metadata.
///
/// Produced by `ConnRegistry::writer_lost` for every connection that was bound
/// to the lost writer and still has metadata registered.
#[derive(Clone)]
pub struct BoundConn {
/// Registry id of the connection.
pub conn_id: u64,
/// Metadata recorded for the connection when it was bound.
pub meta: ConnMeta,
}
/// The writer a connection is currently bound to, as returned by
/// `ConnRegistry::get_writer`.
#[derive(Clone)]
pub struct ConnWriter {
/// Id of the writer the connection is bound to.
pub writer_id: u64,
/// Shared RPC writer registered under `writer_id`.
pub writer: Arc<Mutex<RpcWriter>>,
}
/// Registry of client connections and of the writers they are bound to.
///
/// Each table sits behind its own `RwLock`, so operations that touch several
/// tables take those locks one at a time.
pub struct ConnRegistry {
// Connection id -> sender of that connection's response channel.
map: RwLock<HashMap<u64, mpsc::UnboundedSender<MeResponse>>>,
// Shared RPC writers; `bind_writer` keys entries by writer id.
writers: RwLock<HashMap<u64, Arc<Mutex<RpcWriter>>>>,
// Connection id -> id of the writer it is bound to.
writer_for_conn: RwLock<HashMap<u64, u64>>,
// Writer id -> ids of the connections bound to it.
conns_for_writer: RwLock<HashMap<u64, Vec<u64>>>,
// Connection id -> metadata recorded when the connection was bound.
meta: RwLock<HashMap<u64, ConnMeta>>,
// Next connection id handed out by `register`.
next_id: AtomicU64,
}
impl ConnRegistry {
/// Creates an empty registry whose connection ids start at a random value.
pub fn new() -> Self {
    // Avoid fully predictable conn_id sequence from 1.
    // OR-ing with 1 also guarantees the first id is never zero.
    let first_id = rand::random::<u64>() | 1;

    Self {
        map: RwLock::new(HashMap::new()),
        writers: RwLock::new(HashMap::new()),
        writer_for_conn: RwLock::new(HashMap::new()),
        conns_for_writer: RwLock::new(HashMap::new()),
        meta: RwLock::new(HashMap::new()),
        next_id: AtomicU64::new(first_id),
    }
}
pub async fn register(&self) -> (u64, mpsc::UnboundedReceiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
// Unbounded per-connection queue prevents reader-loop HOL blocking on
// slow clients: routing stays non-blocking and preserves message order.
let (tx, rx) = mpsc::unbounded_channel();
self.map.write().await.insert(id, tx);
(id, rx)
@@ -36,7 +59,12 @@ impl ConnRegistry {
pub async fn unregister(&self, id: u64) {
self.map.write().await.remove(&id);
self.writers.write().await.remove(&id);
self.meta.write().await.remove(&id);
if let Some(writer_id) = self.writer_for_conn.write().await.remove(&id) {
if let Some(list) = self.conns_for_writer.write().await.get_mut(&writer_id) {
list.retain(|c| *c != id);
}
}
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
@@ -48,13 +76,58 @@ impl ConnRegistry {
}
}
pub async fn set_writer(&self, id: u64, w: Arc<Mutex<RpcWriter>>) {
let mut guard = self.writers.write().await;
guard.entry(id).or_insert_with(|| w);
/// Records that connection `conn_id` is served by writer `writer_id`.
///
/// Metadata and the writer handle are stored only if no entry exists yet for
/// that key. The connection-to-writer mapping is overwritten, and `conn_id` is
/// appended to the writer's connection list. Each table is locked and released
/// in turn, so the update is not atomic across tables.
pub async fn bind_writer(
    &self,
    conn_id: u64,
    writer_id: u64,
    writer: Arc<Mutex<RpcWriter>>,
    meta: ConnMeta,
) {
    {
        let mut metas = self.meta.write().await;
        metas.entry(conn_id).or_insert(meta);
    }
    {
        let mut owner_of = self.writer_for_conn.write().await;
        owner_of.insert(conn_id, writer_id);
    }
    {
        let mut handles = self.writers.write().await;
        handles.entry(writer_id).or_insert(writer);
    }

    let mut conns_of = self.conns_for_writer.write().await;
    conns_of
        .entry(writer_id)
        .or_insert_with(Vec::new)
        .push(conn_id);
}
pub async fn get_writer(&self, id: u64) -> Option<Arc<Mutex<RpcWriter>>> {
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
let writer_id = {
let guard = self.writer_for_conn.read().await;
guard.get(&conn_id).cloned()
}?;
let writer = {
let guard = self.writers.read().await;
guard.get(&id).cloned()
guard.get(&writer_id).cloned()
}?;
Some(ConnWriter { writer_id, writer })
}
/// Forgets writer `writer_id` and reports the connections it was serving.
///
/// The writer handle and its connection list are removed, and every listed
/// connection loses its connection-to-writer mapping. Connections that still
/// have metadata are returned as `BoundConn`; those without metadata are
/// unbound but not returned.
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
    self.writers.write().await.remove(&writer_id);
    let orphaned = self
        .conns_for_writer
        .write()
        .await
        .remove(&writer_id)
        .unwrap_or_default();

    let mut owner_of = self.writer_for_conn.write().await;
    let metas = self.meta.read().await;
    let lost: Vec<BoundConn> = orphaned
        .into_iter()
        .filter_map(|conn_id| {
            owner_of.remove(&conn_id);
            metas.get(&conn_id).map(|m| BoundConn {
                conn_id,
                meta: m.clone(),
            })
        })
        .collect();
    lost
}
/// Returns a copy of the metadata stored for `conn_id`, or `None` if there is none.
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
    self.meta.read().await.get(&conn_id).cloned()
}
}

View File

@@ -9,14 +9,14 @@ use crate::error::{ProxyError, Result};
use crate::protocol::constants::RPC_CLOSE_EXT_U32;
use super::MePool;
use super::codec::RpcWriter;
use super::wire::build_proxy_req_payload;
use crate::crypto::SecureRandom;
use rand::seq::SliceRandom;
use super::registry::ConnMeta;
impl MePool {
pub async fn send_proxy_req(
&self,
self: &Arc<Self>,
conn_id: u64,
target_dc: i16,
client_addr: SocketAddr,
@@ -32,18 +32,45 @@ impl MePool {
self.proxy_tag.as_deref(),
proto_flags,
);
let meta = ConnMeta {
target_dc,
client_addr,
our_addr,
proto_flags,
};
loop {
if let Some(current) = self.registry.get_writer(conn_id).await {
let send_res = {
if let Ok(mut guard) = current.writer.try_lock() {
let r = guard.send(&payload).await;
drop(guard);
r
} else {
current.writer.lock().await.send(&payload).await
}
};
match send_res {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, writer_id = current.writer_id, "ME write failed");
self.remove_writer_and_reroute(current.writer_id).await;
continue;
}
}
}
let mut writers_snapshot = {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws.iter().cloned().collect();
drop(ws);
ws.clone()
};
let mut candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await;
let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
if candidate_indices.is_empty() {
// Emergency: try to connect to target DC addresses on the fly, then recompute writers
// Emergency connect-on-demand
let map = self.proxy_map_v4.read().await;
if let Some(addrs) = map.get(&(target_dc as i32)) {
let mut shuffled = addrs.clone();
@@ -56,64 +83,71 @@ impl MePool {
}
}
let ws2 = self.writers.read().await;
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws2.iter().cloned().collect();
writers_snapshot = ws2.clone();
drop(ws2);
candidate_indices = self.candidate_indices_for_dc(&writers, target_dc).await;
candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await;
}
if candidate_indices.is_empty() {
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
}
}
candidate_indices.sort_by_key(|idx| {
writers_snapshot[*idx]
.degraded
.load(Ordering::Relaxed)
.then_some(1usize)
.unwrap_or(0)
});
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
// Prefer immediately available writer to avoid waiting on stalled connection.
for offset in 0..candidate_indices.len() {
let cidx = (start + offset) % candidate_indices.len();
let idx = candidate_indices[cidx];
let w = writers[idx].1.clone();
if let Ok(mut guard) = w.try_lock() {
let idx = candidate_indices[(start + offset) % candidate_indices.len()];
let w = &writers_snapshot[idx];
if let Ok(mut guard) = w.writer.try_lock() {
let send_res = guard.send(&payload).await;
drop(guard);
match send_res {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
Ok(()) => {
self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME write failed");
self.remove_writer_and_reroute(w.id).await;
continue;
}
}
}
}
// All writers are currently busy, wait for the selected one.
let w = writers[candidate_indices[start]].1.clone();
match w.lock().await.send(&payload).await {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
let w = writers_snapshot[candidate_indices[start]].clone();
match w.writer.lock().await.send(&payload).await {
Ok(()) => {
self.registry
.bind_writer(conn_id, w.id, w.writer.clone(), meta.clone())
.await;
return Ok(());
}
Err(e) => {
warn!(error = %e, writer_id = w.id, "ME write failed (blocking)");
self.remove_writer_and_reroute(w.id).await;
}
}
}
}
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
if let Some(w) = self.registry.get_writer(conn_id).await {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.lock().await.send(&p).await {
if let Err(e) = w.writer.lock().await.send(&p).await {
debug!(error = %e, "ME close write failed");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
self.remove_writer_and_reroute(w.writer_id).await;
}
} else {
debug!(conn_id, "ME close skipped (writer missing)");
@@ -129,7 +163,7 @@ impl MePool {
pub(super) async fn candidate_indices_for_dc(
&self,
writers: &[(SocketAddr, Arc<Mutex<RpcWriter>>)],
writers: &[super::pool::MeWriter],
target_dc: i16,
) -> Vec<usize> {
let mut preferred = Vec::<SocketAddr>::new();
@@ -165,8 +199,8 @@ impl MePool {
}
let mut out = Vec::new();
for (idx, (addr, _)) in writers.iter().enumerate() {
if preferred.iter().any(|p| p == addr) {
for (idx, w) in writers.iter().enumerate() {
if preferred.iter().any(|p| *p == w.addr) {
out.push(idx);
}
}