Antireplay Improvements + DC Ping

- Fix: `LruCache::get` type ambiguity in stats/mod.rs
  - Changed `self.cache.get(&key.into())` to `self.cache.get(key)`: `key` is already `&[u8]`, so the lookup resolves via `Box<[u8]>: Borrow<[u8]>`
  - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())`: passes an explicit `&[u8]` instead of `&Box<[u8]>` (see the sketch after this item)
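
  A minimal sketch of the two call-site shapes, assuming the `lru` crate and mirroring `ReplayShard` from the diff below; the `Shard` struct, its capacity, and the `u64` values are illustrative only:

  ```rust
  use lru::LruCache;
  use std::num::NonZeroUsize;

  // Illustrative shard: the real ReplayShard stores ReplayEntry values.
  struct Shard {
      cache: LruCache<Box<[u8]>, u64>,
  }

  impl Shard {
      fn new() -> Self {
          Self { cache: LruCache::new(NonZeroUsize::new(16).unwrap()) }
      }

      // `key` is already `&[u8]`, so the lookup resolves Q = [u8] via
      // `Box<[u8]>: Borrow<[u8]>`; no `.into()` (whose target type was
      // ambiguous) is needed.
      fn check(&mut self, key: &[u8]) -> bool {
          self.cache.get(key).is_some()
      }

      // With an owned `Box<[u8]>`, `&key` would be `&Box<[u8]>`; `key.as_ref()`
      // forces the same unambiguous `&[u8]` path for peek/pop.
      fn evict(&mut self, key: Box<[u8]>) {
          if self.cache.peek(key.as_ref()).is_some() {
              self.cache.pop(key.as_ref());
          }
      }
  }

  fn main() {
      let mut shard = Shard::new();
      shard.cache.put(Box::from(&b"abc"[..]), 1);
      assert!(shard.check(b"abc"));
      shard.evict(Box::from(&b"abc"[..]));
      assert!(!shard.check(b"abc"));
  }
  ```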

- Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval):
  - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT (worked example in the sketch after this list)
  - `connect()` – measures the RTT of each real connection and updates the EMA
  - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error
  - `run_health_checks(prefer_ipv6)` – accepts an IPv6 preference parameter, rotates the probed DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced from 60s to 30s; failed checks now mark an upstream unhealthy after more than 3 consecutive failures
  - `DcPingResult` / `StartupPingResult` – public structures for display
  - DC ping at startup: calls `upstream_manager.ping_all_dcs(prefer_ipv6)` before the accept loop and prints a connectivity table via `println!` (always visible)
  - Health checks respect `prefer_ipv6`: the spawned task passes it to `run_health_checks(prefer_ipv6)`
  - Exported `StartupPingResult` and `DcPingResult`
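
To make the α=0.3 smoothing concrete, here is a self-contained sketch of the EMA update described above (the `LatencyEma` shape matches the diff below; the RTT samples are made-up numbers):

```rust
/// Exponential moving average: new = prev * (1 - α) + sample * α.
struct LatencyEma {
    value_ms: Option<f64>,
    alpha: f64,
}

impl LatencyEma {
    fn update(&mut self, sample_ms: f64) {
        self.value_ms = Some(match self.value_ms {
            None => sample_ms, // first sample seeds the average
            Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
        });
    }
}

fn main() {
    // α = 0.3, RTT samples 100 ms, 200 ms, 100 ms:
    //   100 -> 0.7*100 + 0.3*200 = 130 -> 0.7*130 + 0.3*100 = 121
    let mut ema = LatencyEma { value_ms: None, alpha: 0.3 };
    for sample in [100.0, 200.0, 100.0] {
        ema.update(sample);
    }
    println!("EMA ≈ {:.1} ms", ema.value_ms.unwrap()); // prints ~121.0 ms
}
```

A larger α would track RTT spikes faster; 0.3 damps one-off outliers while still converging within a handful of health-check cycles.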

- Summary: startup DC ping with RTT, rotating health check with EMA latency tracking on a 30-second interval, and upstreams correctly marked unhealthy after more than 3 consecutive failures (sketched below).
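
A simplified, synchronous model of the health-check loop described above; names like `health_check_cycle`, `probe`, and `MAX_FAILS` are illustrative, while the rotation and the `fails > 3` threshold follow the diff (the real task sleeps 30s between cycles and probes via `connect_via_upstream` under a timeout):

```rust
const NUM_DCS: usize = 5;   // TG_DATACENTERS_V4 / _V6 each list five DCs
const MAX_FAILS: u32 = 3;   // unhealthy once fails exceeds this, i.e. on the 4th failure

struct Upstream {
    healthy: bool,
    fails: u32,
}

/// One check cycle: probe a single rotating DC through every upstream.
fn health_check_cycle(
    upstreams: &mut [Upstream],
    dc_rotation: &mut usize,
    probe: impl Fn(usize, usize) -> bool, // (upstream index, dc index) -> reachable?
) {
    let dc_idx = *dc_rotation % NUM_DCS; // DC1 -> DC2 -> ... -> DC5 -> DC1 ...
    *dc_rotation += 1;
    for (i, u) in upstreams.iter_mut().enumerate() {
        if probe(i, dc_idx) {
            u.healthy = true; // success also re-enables a previously unhealthy upstream
            u.fails = 0;
        } else {
            u.fails += 1;
            if u.fails > MAX_FAILS {
                u.healthy = false;
            }
        }
    }
}

fn main() {
    let mut upstreams = vec![Upstream { healthy: true, fails: 0 }];
    let mut dc_rotation = 0usize;
    // Five consecutive failing cycles: the unhealthy flag flips on the 4th failure.
    for _ in 0..5 {
        health_check_cycle(&mut upstreams, &mut dc_rotation, |_upstream, _dc| false);
    }
    assert!(!upstreams[0].healthy);
    assert_eq!(upstreams[0].fails, 5);
}
```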

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Alexey
2026-02-07 20:18:25 +03:00
parent 92cedabc81
commit 158eae8d2a
4 changed files with 458 additions and 212 deletions

View File

@@ -26,11 +26,6 @@ use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
/// Parse command-line arguments.
///
/// Usage: telemt [config_path] [--silent] [--log-level <level>]
///
/// Returns (config_path, silent_flag, log_level_override)
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut silent = false;
@@ -40,33 +35,23 @@ fn parse_cli() -> (String, bool, Option<String>) {
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => {
silent = true;
}
"--silent" | "-s" => { silent = true; }
"--log-level" => {
i += 1;
if i < args.len() {
log_level = Some(args[i].clone());
}
if i < args.len() { log_level = Some(args[i].clone()); }
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs (only warn/error)");
eprintln!(" --log-level <LEVEL> Set log level: debug|verbose|normal|silent");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
std::process::exit(0);
}
s if !s.starts_with('-') => {
config_path = s.to_string();
}
other => {
eprintln!("Unknown option: {}", other);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
}
i += 1;
}
@@ -76,20 +61,17 @@ fn parse_cli() -> (String, bool, Option<String>) {
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. Parse CLI arguments
let (config_path, cli_silent, cli_log_level) = parse_cli();
// 2. Load config (tracing not yet initialized — errors go to stderr)
let config = match ProxyConfig::load(&config_path) {
Ok(c) => c,
Err(e) => {
if std::path::Path::new(&config_path).exists() {
eprintln!("[telemt] Error: Failed to load config '{}': {}", config_path, e);
eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
} else {
let default = ProxyConfig::default();
let toml_str = toml::to_string_pretty(&default).unwrap();
std::fs::write(&config_path, toml_str).unwrap();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
default
}
@@ -97,80 +79,90 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
};
if let Err(e) = config.validate() {
eprintln!("[telemt] Error: Invalid configuration: {}", e);
eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
// 3. Determine effective log level
// Priority: RUST_LOG env > CLI flags > config file > default (normal)
let effective_log_level = if cli_silent {
LogLevel::Silent
} else if let Some(ref level_str) = cli_log_level {
LogLevel::from_str_loose(level_str)
} else if let Some(ref s) = cli_log_level {
LogLevel::from_str_loose(s)
} else {
config.general.log_level.clone()
};
// 4. Initialize tracing
let filter = if std::env::var("RUST_LOG").is_ok() {
// RUST_LOG takes absolute priority
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
fmt()
.with_env_filter(filter)
.init();
fmt().with_env_filter(filter).init();
// 5. Log startup info (operational — respects log level)
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
info!(
"Modes: classic={} secure={} tls={}",
info!("Modes: classic={} secure={} tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls
);
config.general.modes.tls);
info!("TLS domain: {}", config.censorship.tls_domain);
info!(
"Mask: {} -> {}:{}",
info!("Mask: {} -> {}:{}",
config.censorship.mask,
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
config.censorship.mask_port
);
config.censorship.mask_port);
if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain (www.google.com). Consider setting a custom domain.");
}
let prefer_ipv6 = config.general.prefer_ipv6;
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
// Initialize ReplayChecker
let replay_checker = Arc::new(ReplayChecker::new(
config.access.replay_check_len,
Duration::from_secs(config.access.replay_window_secs),
));
// Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
// Initialize Buffer Pool (16KB buffers, max 4096 cached ≈ 64MB)
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start health checks
// === Startup DC Ping ===
println!("=== Telegram DC Connectivity ===");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
println!(" via {}", upstream_result.upstream_name);
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) {
(Some(rtt), _) => {
println!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
}
(None, Some(err)) => {
println!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
}
(None, None) => {
println!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
}
}
}
}
println!("================================");
// Start background tasks
let um_clone = upstream_manager.clone();
tokio::spawn(async move {
um_clone.run_health_checks().await;
um_clone.run_health_checks(prefer_ipv6).await;
});
let rc_clone = replay_checker.clone();
tokio::spawn(async move {
rc_clone.run_periodic_cleanup().await;
});
// Detect public IP (once at startup)
let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
// 6. Start listeners
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
@@ -185,7 +177,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
// Determine public IP for tg:// links
let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip
} else if listener_conf.ip.is_unspecified() {
@@ -198,30 +189,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
listener_conf.ip
};
// 7. Print proxy links (always visible — uses println!, not tracing)
if !config.show_link.is_empty() {
println!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) {
println!("[{}]", user_name);
if config.general.modes.classic {
println!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret);
}
if config.general.modes.secure {
println!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex);
}
} else {
warn!("User '{}' in show_link not found in users", user_name);
warn!("User '{}' in show_link not found", user_name);
}
}
println!("------------------------");
@@ -236,11 +223,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
if listeners.is_empty() {
error!("No listeners could be started. Exiting.");
error!("No listeners. Exiting.");
std::process::exit(1);
}
// 8. Accept loop
for listener in listeners {
let config = config.clone();
let stats = stats.clone();
@@ -262,14 +248,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
tokio::spawn(async move {
if let Err(e) = ClientHandler::new(
stream,
peer_addr,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng
stream, peer_addr, config, stats,
upstream_manager, replay_checker, buffer_pool, rng
).run().await {
debug!(peer = %peer_addr, error = %e, "Connection error");
}
@@ -284,7 +264,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});
}
// 9. Wait for shutdown signal
match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e),

View File

@@ -1,32 +1,28 @@
//! Statistics
//! Statistics and replay protection
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Instant, Duration};
use dashmap::DashMap;
use parking_lot::{RwLock, Mutex};
use parking_lot::Mutex;
use lru::LruCache;
use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use tracing::debug;
// ============= Stats =============
/// Thread-safe statistics
#[derive(Default)]
pub struct Stats {
// Global counters
connects_all: AtomicU64,
connects_bad: AtomicU64,
handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>,
// Start time
start_time: RwLock<Option<Instant>>,
start_time: parking_lot::RwLock<Option<Instant>>,
}
/// Per-user statistics
#[derive(Default)]
pub struct UserStats {
pub connects: AtomicU64,
@@ -44,42 +40,20 @@ impl Stats {
stats
}
// Global stats
pub fn increment_connects_all(&self) {
self.connects_all.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.connects.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.curr_connects.fetch_add(1, Ordering::Relaxed);
}
pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -89,47 +63,33 @@ impl Stats {
}
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats
.get(user)
self.user_stats.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats
.entry(user.to_string())
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
}
pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
self.user_stats.entry(user.to_string()).or_default()
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats
.get(user)
self.user_stats.get(user)
.map(|s| {
s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed)
@@ -144,21 +104,27 @@ impl Stats {
}
}
/// Sharded Replay attack checker using LRU cache + sliding window
/// Uses multiple independent LRU caches to reduce lock contention
// ============= Replay Checker =============
pub struct ReplayChecker {
shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize,
window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
}
struct ReplayEntry {
seen_at: Instant,
seq: u64,
}
struct ReplayShard {
cache: LruCache<Vec<u8>, ReplayEntry>,
queue: VecDeque<(Instant, Vec<u8>)>,
cache: LruCache<Box<[u8]>, ReplayEntry>,
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
seq_counter: u64,
}
impl ReplayShard {
@@ -166,33 +132,60 @@ impl ReplayShard {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
}
}
fn next_seq(&mut self) -> u64 {
self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now - window;
while let Some((ts, _)) = self.queue.front() {
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (ts_old, key_old) = self.queue.pop_front().unwrap();
if let Some(entry) = self.cache.get(&key_old) {
if entry.seen_at <= ts_old {
self.cache.pop(&key_old);
}
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
// between Borrow<[u8]> and Borrow<Box<[u8]>>
if let Some(entry) = self.cache.peek(key.as_ref()) {
if entry.seq == queue_seq {
self.cache.pop(key.as_ref());
}
}
}
}
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
self.cleanup(now, window);
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
self.cache.get(key).is_some()
}
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
self.cleanup(now, window);
let seq = self.next_seq();
let boxed_key: Box<[u8]> = key.into();
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
self.queue.push_back((now, boxed_key, seq));
}
fn len(&self) -> usize {
self.cache.len()
}
}
impl ReplayChecker {
/// Create new replay checker with specified capacity per shard
/// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize, window: Duration) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
@@ -206,50 +199,114 @@ impl ReplayChecker {
shards,
shard_mask: num_shards - 1,
window,
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
}
}
fn get_shard(&self, key: &[u8]) -> usize {
fn get_shard_idx(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
fn check(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
let now = Instant::now();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.get(&key).is_some()
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let found = shard.check(data, Instant::now(), self.window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add(&self, data: &[u8]) {
let shard_idx = self.get_shard(data);
let mut shard = self.shards[shard_idx].lock();
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
}
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
ReplayStats {
total_entries,
total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
window_secs: self.window.as_secs(),
}
}
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs(self.window.as_secs().max(1) / 2)
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let key = data.to_vec();
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
shard.queue.push_back((now, key));
let after = shard.len();
cleaned += before.saturating_sub(after);
}
pub fn check_handshake(&self, data: &[u8]) -> bool {
self.check(data)
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
pub fn add_handshake(&self, data: &[u8]) {
self.add(data)
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.check(data)
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 { 0.0 }
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.add(data)
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 { 0.0 }
else { self.total_queue_len as f64 / self.total_entries as f64 }
}
}
@@ -260,28 +317,60 @@ mod tests {
#[test]
fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new());
let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats);
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_all();
stats.increment_connects_all();
assert_eq!(stats.get_connects_all(), 3);
}
#[test]
fn test_replay_checker_sharding() {
fn test_replay_checker_basic() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
let data1 = b"test1";
let data2 = b"test2";
assert!(!checker.check_handshake(b"test1"));
checker.add_handshake(b"test1");
assert!(checker.check_handshake(b"test1"));
assert!(!checker.check_handshake(b"test2"));
}
checker.add_handshake(data1);
assert!(checker.check_handshake(data1));
assert!(!checker.check_handshake(data2));
#[test]
fn test_replay_checker_duplicate_add() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"dup");
checker.add_handshake(b"dup");
assert!(checker.check_handshake(b"dup"));
}
checker.add_handshake(data2);
assert!(checker.check_handshake(data2));
#[test]
fn test_replay_checker_expiration() {
let checker = ReplayChecker::new(100, Duration::from_millis(50));
checker.add_handshake(b"expire");
assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire"));
}
#[test]
fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"k1");
checker.add_handshake(b"k2");
checker.check_handshake(b"k1");
checker.check_handshake(b"k3");
let stats = checker.stats();
assert_eq!(stats.total_additions, 2);
assert_eq!(stats.total_checks, 2);
assert_eq!(stats.total_hits, 1);
}
#[test]
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check(&i.to_le_bytes()));
}
assert_eq!(checker.stats().total_entries, 500);
}
}

View File

@@ -10,4 +10,4 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*;
pub use socks::*;
pub use upstream::UpstreamManager;
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};

View File

@@ -1,26 +1,78 @@
//! Upstream Management
//! Upstream Management with RTT tracking and startup ping
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time::Instant;
use rand::Rng;
use tracing::{debug, warn, error, info};
use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::create_outgoing_socket_bound;
use crate::transport::socks::{connect_socks4, connect_socks5};
// ============= RTT Tracking =============
/// Exponential moving average for latency tracking
#[derive(Debug, Clone)]
struct LatencyEma {
/// Current EMA value in milliseconds (None = no data yet)
value_ms: Option<f64>,
/// Smoothing factor (0.0 - 1.0, higher = more weight to recent)
alpha: f64,
}
impl LatencyEma {
fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
// ============= Upstream State =============
#[derive(Debug)]
struct UpstreamState {
config: UpstreamConfig,
healthy: bool,
fails: u32,
last_check: std::time::Instant,
/// Latency EMA (alpha=0.3 — moderate smoothing)
latency: LatencyEma,
}
/// Result of a single DC ping
#[derive(Debug, Clone)]
pub struct DcPingResult {
pub dc_idx: usize,
pub dc_addr: SocketAddr,
pub rtt_ms: Option<f64>,
pub error: Option<String>,
}
/// Result of startup ping across all DCs
#[derive(Debug, Clone)]
pub struct StartupPingResult {
pub results: Vec<DcPingResult>,
pub upstream_name: String,
}
// ============= Upstream Manager =============
#[derive(Clone)]
pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
@@ -35,6 +87,7 @@ impl UpstreamManager {
healthy: true,
fails: 0,
last_check: std::time::Instant::now(),
latency: LatencyEma::new(0.3),
})
.collect();
@@ -43,7 +96,7 @@ impl UpstreamManager {
}
}
/// Select an upstream using Weighted Round Robin (simplified)
/// Select an upstream using weighted selection among healthy upstreams
async fn select_upstream(&self) -> Option<usize> {
let upstreams = self.upstreams.read().await;
if upstreams.is_empty() {
@@ -57,11 +110,9 @@ impl UpstreamManager {
.collect();
if healthy_indices.is_empty() {
// If all unhealthy, try any random one
return Some(rand::rng().gen_range(0..upstreams.len()));
}
// Weighted selection
let total_weight: u32 = healthy_indices.iter()
.map(|&i| upstreams[i].config.weight as u32)
.sum();
@@ -92,15 +143,19 @@ impl UpstreamManager {
guard[idx].config.clone()
};
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
if !u.healthy {
debug!("Upstream recovered: {:?}", u.config);
debug!(rtt_ms = rtt_ms, "Upstream recovered: {:?}", u.config);
}
u.healthy = true;
u.fails = 0;
u.latency.update(rtt_ms);
}
Ok(stream)
},
@@ -108,10 +163,10 @@ impl UpstreamManager {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) {
u.fails += 1;
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
warn!("Upstream {:?} failed: {}. Consecutive fails: {}", u.config, e, u.fails);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream disabled due to failures: {:?}", u.config);
warn!("Upstream marked unhealthy: {:?}", u.config);
}
}
Err(e)
@@ -145,7 +200,7 @@ impl UpstreamManager {
Ok(stream)
},
UpstreamType::Socks4 { address, interface, user_id } => {
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
info!("Connecting to {} via SOCKS4 {}", target, address);
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
@@ -174,7 +229,7 @@ impl UpstreamManager {
Ok(stream)
},
UpstreamType::Socks5 { address, interface, username, password } => {
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
info!("Connecting to {} via SOCKS5 {}", target, address);
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
@@ -205,12 +260,109 @@ impl UpstreamManager {
}
}
/// Background task to check health
pub async fn run_health_checks(&self) {
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
// ============= Startup Ping =============
/// Ping all Telegram DCs through all upstreams and return results.
///
/// Used at startup to display connectivity and latency info.
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
guard.iter().enumerate()
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
}
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut dc_results = Vec::new();
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout(
Duration::from_secs(5),
self.ping_single_dc(upstream_config, dc_addr)
).await;
let result = match ping_result {
Ok(Ok(rtt_ms)) => {
// Update latency EMA
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.latency.update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some("timeout (5s)".to_string()),
},
};
dc_results.push(result);
}
all_results.push(StartupPingResult {
results: dc_results,
upstream_name,
});
}
all_results
}
/// Ping a single DC: TCP connect, measure RTT, then drop.
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?;
let rtt = start.elapsed();
Ok(rtt.as_secs_f64() * 1000.0)
}
// ============= Health Checks =============
/// Background health check task.
///
/// Every 30 seconds, pings one representative DC per upstream.
/// Measures RTT and updates health status.
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
// Rotate through DCs across check cycles
let mut dc_rotation = 0usize;
loop {
tokio::time::sleep(Duration::from_secs(60)).await;
tokio::time::sleep(Duration::from_secs(30)).await;
let check_dc_idx = dc_rotation % datacenters.len();
dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[check_dc_idx], TG_DATACENTER_PORT);
let count = self.upstreams.read().await.len();
for i in 0..count {
@@ -219,6 +371,7 @@ impl UpstreamManager {
guard[i].config.clone()
};
let start = Instant::now();
let result = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, check_target)
@@ -229,17 +382,42 @@ impl UpstreamManager {
match result {
Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
u.latency.update(rtt_ms);
if !u.healthy {
debug!("Upstream recovered: {:?}", u.config);
info!(
rtt_ms = format!("{:.1}", rtt_ms),
dc = check_dc_idx + 1,
"Upstream recovered: {:?}", u.config
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
debug!("Health check failed for {:?}: {}", u.config, e);
u.fails += 1;
debug!(
dc = check_dc_idx + 1,
fails = u.fails,
"Health check failed for {:?}: {}", u.config, e
);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (health check): {:?}", u.config);
}
}
Err(_) => {
debug!("Health check timeout for {:?}", u.config);
u.fails += 1;
debug!(
dc = check_dc_idx + 1,
fails = u.fails,
"Health check timeout for {:?}", u.config
);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout): {:?}", u.config);
}
}
}
u.last_check = std::time::Instant::now();