New Relay on Tokio Copy Bidirectional
@@ -1,195 +1,466 @@
-//! Bidirectional Relay
+//! Bidirectional Relay — poll-based, no head-of-line blocking
+//!
+//! ## What changed and why
+//!
+//! The previous implementation used a single-task `select! { biased; ... }` loop
+//! where each branch called `write_all()`. This caused head-of-line blocking:
+//! while `write_all()` waited for a slow writer (e.g. a client on 3G downloading
+//! media), the entire loop was blocked — the other direction couldn't make progress.
+//!
+//! Symptoms observed in production:
+//! - Media loading at ~8 KB/s despite a fast server connection
+//! - Stop-and-go pattern with 50–500 ms gaps between chunks
+//! - `biased` select starving the S→C direction
+//! - Some users unable to load media at all
+//!
+//! ## New architecture
+//!
+//! Uses `tokio::io::copy_bidirectional`, which polls both directions concurrently
+//! in a single task via non-blocking `poll_read` / `poll_write` calls:
+//!
+//! Old (select! + write_all — BLOCKING):
+//!
+//!     loop {
+//!         select! {
+//!             biased;
+//!             data = client.read() => { server.write_all(data).await; }  ← BLOCKS here
+//!             data = server.read() => { client.write_all(data).await; }  ← can't run
+//!         }
+//!     }
+//!
+//! New (copy_bidirectional — CONCURRENT):
+//!
+//!     poll(cx) {
+//!         // Both directions polled in the same poll cycle
+//!         C→S: poll_read(client) → poll_write(server)   // non-blocking
+//!         S→C: poll_read(server) → poll_write(client)   // non-blocking
+//!         // If one writer is Pending, the other direction still progresses
+//!     }
+//!
+//! Benefits:
+//! - No head-of-line blocking: a slow client download doesn't block uploads
+//! - No biased starvation: fair polling of both directions
+//! - Proper flush: `copy_bidirectional` calls `poll_flush` when the reader stalls,
+//!   so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
+//! - No deadlock risk: the old write_all could deadlock when both TCP buffers filled;
+//!   the poll-based approach lets TCP flow control work correctly
+//!
+//! Stats tracking:
+//! - `StatsIo` wraps the client side and intercepts `poll_read` / `poll_write`
+//! - `poll_read` on the client = C→S (client sending) → `octets_from`, `msgs_from`
+//! - `poll_write` on the client = S→C (to client) → `octets_to`, `msgs_to`
+//! - `SharedCounters` (atomics) let the watchdog read stats without locking
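The pattern described above can be exercised outside the proxy with nothing but public tokio APIs. A minimal, self-contained sketch (not part of this commit): `tokio::io::duplex` pipes stand in for the client and server sockets, and a plain `tokio::time::timeout` stands in for the watchdog that appears later in this file.

```rust
use std::time::Duration;

use tokio::io::{copy_bidirectional, duplex, AsyncReadExt, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // In-memory pipes stand in for the client and server sockets.
    let (mut client_side, mut relay_client_end) = duplex(64 * 1024);
    let (mut relay_server_end, mut server_side) = duplex(64 * 1024);

    // Peers: the "client" sends a request, the "server" echoes it back.
    let client = tokio::spawn(async move {
        client_side.write_all(b"hello").await?;
        let mut buf = [0u8; 5];
        client_side.read_exact(&mut buf).await?;
        client_side.shutdown().await?;
        Ok::<_, std::io::Error>(buf)
    });
    let server = tokio::spawn(async move {
        let mut buf = [0u8; 5];
        server_side.read_exact(&mut buf).await?;
        server_side.write_all(&buf).await?;
        server_side.shutdown().await?;
        Ok::<_, std::io::Error>(())
    });

    // The relay itself: one future drives both directions, and a timeout
    // bounds the whole transfer (a stand-in for the activity watchdog).
    let (c2s, s2c) = tokio::time::timeout(
        Duration::from_secs(5),
        copy_bidirectional(&mut relay_client_end, &mut relay_server_end),
    )
    .await
    .expect("relay timed out")?;

    println!("relayed {c2s} bytes C->S, {s2c} bytes S->C");
    assert_eq!(&client.await.unwrap()?, b"hello");
    server.await.unwrap()?;
    Ok(())
}
```

Because both directions are polled by the single `copy_bidirectional` future, a stalled writer on one side never prevents the other direction from making progress.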
-use std::sync::Arc;
-use std::time::Duration;
-use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
-use tokio::time::Instant;
-use tracing::{debug, trace, warn};
-use crate::error::Result;
-use crate::stats::Stats;
-use crate::stream::BufferPool;
+use std::io;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::task::{Context, Poll};
+use std::time::Duration;
+use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
+use tokio::time::Instant;
+use tracing::{debug, trace, warn};
+use crate::error::Result;
+use crate::stats::Stats;
+use crate::stream::BufferPool;

-// Activity timeout for iOS compatibility (30 minutes)
-const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
+// ============= Constants =============
-/// Relay data bidirectionally between client and server.
-///
-/// Uses a single-task select!-based loop instead of spawning two tasks.
-/// This eliminates:
-/// - 2× task spawn overhead per connection
-/// - Zombie task problem (old code used select! on JoinHandles but
-///   never aborted the losing task — it would run for up to 30 min)
-/// - Extra Arc<AtomicU64> allocations for cross-task byte counters
-///
-/// The flush()-per-write was also removed: TCP_NODELAY is set on all
-/// sockets (socket.rs), so data is pushed immediately without Nagle
-/// buffering. Explicit flush() on every small read was causing a
-/// syscall storm and defeating CryptoWriter's internal coalescing.
-pub async fn relay_bidirectional<CR, CW, SR, SW>(
-    mut client_reader: CR,
-    mut client_writer: CW,
-    mut server_reader: SR,
-    mut server_writer: SW,
-    user: &str,
-    stats: Arc<Stats>,
-    buffer_pool: Arc<BufferPool>,
-) -> Result<()>
-where
-    CR: AsyncRead + Unpin + Send + 'static,
-    CW: AsyncWrite + Unpin + Send + 'static,
-    SR: AsyncRead + Unpin + Send + 'static,
-    SW: AsyncWrite + Unpin + Send + 'static,
-{
-    // Get buffers from pool — one per direction
-    let mut c2s_buf = buffer_pool.get();
-    let cap = c2s_buf.capacity();
-    c2s_buf.resize(cap, 0);
+/// Activity timeout for iOS compatibility.
+///
+/// iOS keeps Telegram connections alive in background for up to 30 minutes.
+/// Closing earlier causes unnecessary reconnects and handshake overhead.
+const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);

-    let mut s2c_buf = buffer_pool.get();
-    let cap = s2c_buf.capacity();
-    s2c_buf.resize(cap, 0);
+/// Watchdog check interval — also used for periodic rate logging.
+///
+/// 10 seconds gives responsive timeout detection (±10s accuracy)
+/// without measurable overhead from atomic reads.
+const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);

-    let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
+// ============= CombinedStream =============
-    let mut c2s_total: u64 = 0;
-    let mut s2c_total: u64 = 0;
-    let mut c2s_msgs: u64 = 0;
-    let mut s2c_msgs: u64 = 0;
+/// Combines separate read and write halves into a single bidirectional stream.
+///
+/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
+/// but the handshake layer produces split reader/writer pairs
+/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
+///
+/// This wrapper reunifies them with zero overhead — each trait method
+/// delegates directly to the corresponding half. No buffering, no copies.
+///
+/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
+/// so there's no aliasing even though both are called on the same `&mut self`.
+struct CombinedStream<R, W> {
+    reader: R,
+    writer: W,
+}

-    // For periodic rate logging
-    let mut c2s_prev: u64 = 0;
-    let mut s2c_prev: u64 = 0;
-    let mut last_log = Instant::now();
+impl<R, W> CombinedStream<R, W> {
+    fn new(reader: R, writer: W) -> Self {
+        Self { reader, writer }
+    }
+}

-    let user_owned = user.to_string();
+impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
+    #[inline]
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
+    }
+}
-    loop {
-        tokio::select! {
-            biased;
-
-            // Client -> Server direction
-            result = tokio::time::timeout(activity_timeout, client_reader.read(&mut c2s_buf)) => {
-                match result {
-                    Err(_) => {
-                        // Activity timeout
-                        warn!(
-                            user = %user_owned,
-                            c2s_bytes = c2s_total,
-                            s2c_bytes = s2c_total,
-                            "Activity timeout (C->S)"
-                        );
-                        break;
-                    }
-                    Ok(Ok(0)) => {
-                        // Client closed
-                        debug!(
-                            user = %user_owned,
-                            c2s_bytes = c2s_total,
-                            s2c_bytes = s2c_total,
-                            "Client closed connection"
-                        );
-                        break;
-                    }
-                    Ok(Ok(n)) => {
-                        c2s_total += n as u64;
-                        c2s_msgs += 1;
-
-                        stats.add_user_octets_from(&user_owned, n as u64);
-                        stats.increment_user_msgs_from(&user_owned);
-
-                        trace!(user = %user_owned, bytes = n, "C->S");
-
-                        // Write without flush — TCP_NODELAY handles push
-                        if let Err(e) = server_writer.write_all(&c2s_buf[..n]).await {
-                            debug!(user = %user_owned, error = %e, "Write to server failed");
-                            break;
-                        }
-                    }
-                    Ok(Err(e)) => {
-                        debug!(user = %user_owned, error = %e, "Client read error");
-                        break;
-                    }
-                }
-            }
-
-            // Server -> Client direction
-            result = tokio::time::timeout(activity_timeout, server_reader.read(&mut s2c_buf)) => {
-                match result {
-                    Err(_) => {
-                        warn!(
-                            user = %user_owned,
-                            c2s_bytes = c2s_total,
-                            s2c_bytes = s2c_total,
-                            "Activity timeout (S->C)"
-                        );
-                        break;
-                    }
-                    Ok(Ok(0)) => {
-                        debug!(
-                            user = %user_owned,
-                            c2s_bytes = c2s_total,
-                            s2c_bytes = s2c_total,
-                            "Server closed connection"
-                        );
-                        break;
-                    }
-                    Ok(Ok(n)) => {
-                        s2c_total += n as u64;
-                        s2c_msgs += 1;
-
-                        stats.add_user_octets_to(&user_owned, n as u64);
-                        stats.increment_user_msgs_to(&user_owned);
-
-                        trace!(user = %user_owned, bytes = n, "S->C");
-
-                        if let Err(e) = client_writer.write_all(&s2c_buf[..n]).await {
-                            debug!(user = %user_owned, error = %e, "Write to client failed");
-                            break;
-                        }
-                    }
-                    Ok(Err(e)) => {
-                        debug!(user = %user_owned, error = %e, "Server read error");
-                        break;
-                    }
-                }
-            }
-        }
-
-        // Periodic rate logging (every 10s)
-        let elapsed = last_log.elapsed();
-        if elapsed > Duration::from_secs(10) {
-            let secs = elapsed.as_secs_f64();
-            let c2s_delta = c2s_total - c2s_prev;
-            let s2c_delta = s2c_total - s2c_prev;
-
-            debug!(
-                user = %user_owned,
-                c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
-                s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
-                c2s_total = c2s_total,
-                s2c_total = s2c_total,
-                "Relay active"
-            );
-
-            c2s_prev = c2s_total;
-            s2c_prev = s2c_total;
-            last_log = Instant::now();
-        }
-    }
-
-    // Clean shutdown of both directions
-    let _ = server_writer.shutdown().await;
-    let _ = client_writer.shutdown().await;
-
-    debug!(
-        user = %user_owned,
-        c2s_bytes = c2s_total,
-        s2c_bytes = s2c_total,
-        c2s_msgs = c2s_msgs,
-        s2c_msgs = s2c_msgs,
-        "Relay finished"
-    );
-
-    Ok(())
+impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
+    #[inline]
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
+    }
+
+    #[inline]
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.get_mut().writer).poll_flush(cx)
+    }
+
+    #[inline]
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
+    }
+}
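A hypothetical call site for the wrapper above, assuming plain `TcpStream` halves rather than the crate's `CryptoReader`/`CryptoWriter` stacks: recombining the split halves is what lets `copy_bidirectional`, which wants `AsyncRead + AsyncWrite` on each argument, accept them.

```rust
use tokio::io::copy_bidirectional;
use tokio::net::TcpStream;

// Sketch only: `CombinedStream` is the wrapper defined above; `relay_plain`
// is an illustrative helper, not a function from this commit.
async fn relay_plain(client: TcpStream, server: TcpStream) -> std::io::Result<(u64, u64)> {
    let (cr, cw) = client.into_split();
    let (sr, sw) = server.into_split();

    // Each CombinedStream is AsyncRead + AsyncWrite + Unpin, so
    // copy_bidirectional can drive both directions in one future.
    let mut client = CombinedStream::new(cr, cw);
    let mut server = CombinedStream::new(sr, sw);
    copy_bidirectional(&mut client, &mut server).await
}
```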
+// ============= SharedCounters =============
+
+/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
+///
+/// Using `Relaxed` ordering is sufficient because:
+/// - Counters are monotonically increasing (no ABA problem)
+/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
+/// - No ordering dependencies between different counters
+struct SharedCounters {
+    /// Bytes read from client (C→S direction)
+    c2s_bytes: AtomicU64,
+    /// Bytes written to client (S→C direction)
+    s2c_bytes: AtomicU64,
+    /// Number of poll_read completions (≈ C→S chunks)
+    c2s_ops: AtomicU64,
+    /// Number of poll_write completions (≈ S→C chunks)
+    s2c_ops: AtomicU64,
+    /// Milliseconds since relay epoch of last I/O activity
+    last_activity_ms: AtomicU64,
+}
+
+impl SharedCounters {
+    fn new() -> Self {
+        Self {
+            c2s_bytes: AtomicU64::new(0),
+            s2c_bytes: AtomicU64::new(0),
+            c2s_ops: AtomicU64::new(0),
+            s2c_ops: AtomicU64::new(0),
+            last_activity_ms: AtomicU64::new(0),
+        }
+    }
+
+    /// Record activity at this instant.
+    #[inline]
+    fn touch(&self, now: Instant, epoch: Instant) {
+        let ms = now.duration_since(epoch).as_millis() as u64;
+        self.last_activity_ms.store(ms, Ordering::Relaxed);
+    }
+
+    /// How long since last recorded activity.
+    fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
+        let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
+        let now_ms = now.duration_since(epoch).as_millis() as u64;
+        Duration::from_millis(now_ms.saturating_sub(last_ms))
+    }
+}
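The `Relaxed`-ordering rationale above reduces to: each counter is an independent, monotonically increasing value, and a reader that observes a slightly stale value loses nothing. A standalone sketch of that pattern (illustrative, not from the commit):

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

#[tokio::main]
async fn main() {
    // Relay side bumps the counter; a watchdog-style task samples it.
    let bytes = Arc::new(AtomicU64::new(0));

    let watchdog = {
        let bytes = Arc::clone(&bytes);
        tokio::spawn(async move {
            for _ in 0..5 {
                tokio::time::sleep(Duration::from_millis(10)).await;
                // May be slightly stale, but never torn or decreasing.
                println!("observed so far: {}", bytes.load(Ordering::Relaxed));
            }
        })
    };

    for _ in 0..1_000 {
        bytes.fetch_add(1, Ordering::Relaxed); // e.g. bytes relayed
        tokio::task::yield_now().await;
    }
    watchdog.await.unwrap();
}
```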
+// ============= StatsIo =============
+
+/// Transparent I/O wrapper that tracks per-user statistics and activity.
+///
+/// Wraps the **client** side of the relay. Direction mapping:
+///
+/// | poll method  | direction | stats updated                        |
+/// |--------------|-----------|--------------------------------------|
+/// | `poll_read`  | C→S       | `octets_from`, `msgs_from`, counters |
+/// | `poll_write` | S→C       | `octets_to`, `msgs_to`, counters     |
+///
+/// Both update the shared activity timestamp for the watchdog.
+///
+/// Note on message counts: the original code counted one `read()`/`write_all()`
+/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
+/// Byte counts are identical; op counts may differ slightly due to different
+/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
+struct StatsIo<S> {
+    inner: S,
+    counters: Arc<SharedCounters>,
+    stats: Arc<Stats>,
+    user: String,
+    epoch: Instant,
+}
+
+impl<S> StatsIo<S> {
+    fn new(
+        inner: S,
+        counters: Arc<SharedCounters>,
+        stats: Arc<Stats>,
+        user: String,
+        epoch: Instant,
+    ) -> Self {
+        // Mark initial activity so the watchdog doesn't fire before data flows
+        counters.touch(Instant::now(), epoch);
+        Self { inner, counters, stats, user, epoch }
+    }
+}
+
+impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<io::Result<()>> {
+        let this = self.get_mut();
+        let before = buf.filled().len();
+
+        match Pin::new(&mut this.inner).poll_read(cx, buf) {
+            Poll::Ready(Ok(())) => {
+                let n = buf.filled().len() - before;
+                if n > 0 {
+                    // C→S: client sent data
+                    this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
+                    this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
+                    this.counters.touch(Instant::now(), this.epoch);
+
+                    this.stats.add_user_octets_from(&this.user, n as u64);
+                    this.stats.increment_user_msgs_from(&this.user);
+
+                    trace!(user = %this.user, bytes = n, "C->S");
+                }
+                Poll::Ready(Ok(()))
+            }
+            other => other,
+        }
+    }
+}
+
+impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        let this = self.get_mut();
+
+        match Pin::new(&mut this.inner).poll_write(cx, buf) {
+            Poll::Ready(Ok(n)) => {
+                if n > 0 {
+                    // S→C: data written to client
+                    this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
+                    this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
+                    this.counters.touch(Instant::now(), this.epoch);
+
+                    this.stats.add_user_octets_to(&this.user, n as u64);
+                    this.stats.increment_user_msgs_to(&this.user);
+
+                    trace!(user = %this.user, bytes = n, "S->C");
+                }
+                Poll::Ready(Ok(n))
+            }
+            other => other,
+        }
+    }
+
+    #[inline]
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
+    }
+
+    #[inline]
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
+    }
+}
+// ============= Relay =============
+
+/// Relay data bidirectionally between client and server.
+///
+/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
+///
+/// ## API compatibility
+///
+/// Signature is identical to the previous implementation. The `_buffer_pool`
+/// parameter is retained for call-site compatibility — `copy_bidirectional`
+/// manages its own internal buffers (8 KB per direction).
+///
+/// ## Guarantees preserved
+///
+/// - Activity timeout: 30 minutes of inactivity → clean shutdown
+/// - Per-user stats: bytes and ops counted per direction
+/// - Periodic rate logging: every 10 seconds when active
+/// - Clean shutdown: both write sides are shut down on exit
+/// - Error propagation: I/O errors are returned as `ProxyError::Io`
+pub async fn relay_bidirectional<CR, CW, SR, SW>(
+    client_reader: CR,
+    client_writer: CW,
+    server_reader: SR,
+    server_writer: SW,
+    user: &str,
+    stats: Arc<Stats>,
+    _buffer_pool: Arc<BufferPool>,
+) -> Result<()>
+where
+    CR: AsyncRead + Unpin + Send + 'static,
+    CW: AsyncWrite + Unpin + Send + 'static,
+    SR: AsyncRead + Unpin + Send + 'static,
+    SW: AsyncWrite + Unpin + Send + 'static,
+{
+    let epoch = Instant::now();
+    let counters = Arc::new(SharedCounters::new());
+    let user_owned = user.to_string();
+
+    // ── Combine split halves into bidirectional streams ──────────────
+    let client_combined = CombinedStream::new(client_reader, client_writer);
+    let mut server = CombinedStream::new(server_reader, server_writer);
+
+    // Wrap client with stats/activity tracking
+    let mut client = StatsIo::new(
+        client_combined,
+        Arc::clone(&counters),
+        Arc::clone(&stats),
+        user_owned.clone(),
+        epoch,
+    );
+
+    // ── Watchdog: activity timeout + periodic rate logging ──────────
+    let wd_counters = Arc::clone(&counters);
+    let wd_user = user_owned.clone();
+
+    let watchdog = async {
+        let mut prev_c2s: u64 = 0;
+        let mut prev_s2c: u64 = 0;
+
+        loop {
+            tokio::time::sleep(WATCHDOG_INTERVAL).await;
+
+            let now = Instant::now();
+            let idle = wd_counters.idle_duration(now, epoch);
+
+            // ── Activity timeout ────────────────────────────────────
+            if idle >= ACTIVITY_TIMEOUT {
+                let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
+                let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
+                warn!(
+                    user = %wd_user,
+                    c2s_bytes = c2s,
+                    s2c_bytes = s2c,
+                    idle_secs = idle.as_secs(),
+                    "Activity timeout"
+                );
+                return; // Causes select! to cancel copy_bidirectional
+            }
+
+            // ── Periodic rate logging ───────────────────────────────
+            let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
+            let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
+            let c2s_delta = c2s - prev_c2s;
+            let s2c_delta = s2c - prev_s2c;
+
+            if c2s_delta > 0 || s2c_delta > 0 {
+                let secs = WATCHDOG_INTERVAL.as_secs_f64();
+                debug!(
+                    user = %wd_user,
+                    c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
+                    s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
+                    c2s_total = c2s,
+                    s2c_total = s2c,
+                    "Relay active"
+                );
+            }
+
+            prev_c2s = c2s;
+            prev_s2c = s2c;
+        }
+    };
+
+    // ── Run bidirectional copy + watchdog concurrently ───────────────
+    //
+    // copy_bidirectional polls both directions in the same poll() call:
+    //   C→S: poll_read(client/StatsIo) → poll_write(server)
+    //   S→C: poll_read(server) → poll_write(client/StatsIo)
+    //
+    // When one direction's writer returns Pending, the other direction
+    // continues — no head-of-line blocking.
+    //
+    // When the watchdog fires, select! drops the copy future,
+    // releasing the &mut borrows on client and server.
+    let copy_result = tokio::select! {
+        result = copy_bidirectional(&mut client, &mut server) => Some(result),
+        _ = watchdog => None, // Activity timeout — cancel relay
+    };
+
+    // ── Clean shutdown ──────────────────────────────────────────────
+    // After select!, the losing future is dropped, borrows released.
+    // Shut down both write sides for clean TCP FIN.
+    let _ = client.shutdown().await;
+    let _ = server.shutdown().await;
+
+    // ── Final logging ───────────────────────────────────────────────
+    let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
+    let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
+    let duration = epoch.elapsed();
+
+    match copy_result {
+        Some(Ok((c2s, s2c))) => {
+            // Normal completion — one side closed the connection
+            debug!(
+                user = %user_owned,
+                c2s_bytes = c2s,
+                s2c_bytes = s2c,
+                c2s_msgs = c2s_ops,
+                s2c_msgs = s2c_ops,
+                duration_secs = duration.as_secs(),
+                "Relay finished"
+            );
+            Ok(())
+        }
+        Some(Err(e)) => {
+            // I/O error in one of the directions
+            let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
+            let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
+            debug!(
+                user = %user_owned,
+                c2s_bytes = c2s,
+                s2c_bytes = s2c,
+                c2s_msgs = c2s_ops,
+                s2c_msgs = s2c_ops,
+                duration_secs = duration.as_secs(),
+                error = %e,
+                "Relay error"
+            );
+            Err(e.into())
+        }
+        None => {
+            // Activity timeout (watchdog fired)
+            let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
+            let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
+            debug!(
+                user = %user_owned,
+                c2s_bytes = c2s,
+                s2c_bytes = s2c,
+                c2s_msgs = c2s_ops,
+                s2c_msgs = s2c_ops,
+                duration_secs = duration.as_secs(),
+                "Relay finished (activity timeout)"
+            );
+            Ok(())
+        }
+    }
+}
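The comments above describe the cancellation sequencing: the watchdog future returning makes `select!` drop the copy future, which releases the `&mut` borrows, after which both write sides can be shut down. A self-contained sketch of just that sequencing (again not the crate's code — the idle threshold and buffer sizes are arbitrary):

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use tokio::io::{copy_bidirectional, duplex, AsyncWriteExt};
use tokio::time::Instant;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // In-memory peers stand in for client and server; they never send data.
    let (_client_peer, mut client) = duplex(8 * 1024);
    let (mut server, _server_peer) = duplex(8 * 1024);

    let epoch = Instant::now();
    let last_activity_ms = Arc::new(AtomicU64::new(0));

    let watchdog = {
        let last = Arc::clone(&last_activity_ms);
        async move {
            loop {
                tokio::time::sleep(Duration::from_millis(100)).await;
                let idle_ms = epoch.elapsed().as_millis() as u64
                    - last.load(Ordering::Relaxed);
                if idle_ms >= 300 {
                    return; // idle too long — tell select! to cancel the copy
                }
            }
        }
    };

    let copied = tokio::select! {
        res = copy_bidirectional(&mut client, &mut server) => Some(res?),
        _ = watchdog => None,
    };

    // The losing future was dropped, so the &mut borrows are released here.
    let _ = client.shutdown().await;
    let _ = server.shutdown().await;

    match copied {
        Some((c2s, s2c)) => println!("finished: {c2s} / {s2c} bytes"),
        None => println!("finished: activity timeout"),
    }
    Ok(())
}
```

Because neither in-memory peer ever sends data, the watchdog wins the race and the sketch exits through the timeout branch.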