Antireplay Improvements + DC Ping
- Fix: LruCache::get type ambiguity in stats/mod.rs - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already &[u8], resolved via Box<[u8]>: Borrow<[u8]>) - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit &[u8] instead of &Box<[u8]>) - Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval): - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT - `connect()` – measures RTT of each real connection and updates EMA - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error - `run_health_checks(prefer_ipv6)` – accepts IPv6 preference parameter, rotates DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced to 30s from 60s, failed checks now mark upstream as unhealthy after 3 consecutive fails - `DcPingResult` / `StartupPingResult` – public structures for display - DC Ping at startup: calls `upstream_manager.ping_all_dcs()` before accept loop, outputs table via `println!` (always visible) - Health checks with `prefer_ipv6`: `run_health_checks(prefer_ipv6)` receives the parameter - Exported `StartupPingResult` and `DcPingResult` - Summary: Startup DC ping with RTT, rotational health-check with EMA latency tracking, 30-second interval, correct unhealthy marking after 3 fails. Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
119
src/main.rs
119
src/main.rs
@@ -26,11 +26,6 @@ use crate::transport::{create_listener, ListenOptions, UpstreamManager};
|
||||
use crate::util::ip::detect_ip;
|
||||
use crate::stream::BufferPool;
|
||||
|
||||
/// Parse command-line arguments.
|
||||
///
|
||||
/// Usage: telemt [config_path] [--silent] [--log-level <level>]
|
||||
///
|
||||
/// Returns (config_path, silent_flag, log_level_override)
|
||||
fn parse_cli() -> (String, bool, Option<String>) {
|
||||
let mut config_path = "config.toml".to_string();
|
||||
let mut silent = false;
|
||||
@@ -40,33 +35,23 @@ fn parse_cli() -> (String, bool, Option<String>) {
|
||||
let mut i = 0;
|
||||
while i < args.len() {
|
||||
match args[i].as_str() {
|
||||
"--silent" | "-s" => {
|
||||
silent = true;
|
||||
}
|
||||
"--silent" | "-s" => { silent = true; }
|
||||
"--log-level" => {
|
||||
i += 1;
|
||||
if i < args.len() {
|
||||
log_level = Some(args[i].clone());
|
||||
}
|
||||
if i < args.len() { log_level = Some(args[i].clone()); }
|
||||
}
|
||||
s if s.starts_with("--log-level=") => {
|
||||
log_level = Some(s.trim_start_matches("--log-level=").to_string());
|
||||
}
|
||||
"--help" | "-h" => {
|
||||
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
|
||||
eprintln!();
|
||||
eprintln!("Options:");
|
||||
eprintln!(" --silent, -s Suppress info logs (only warn/error)");
|
||||
eprintln!(" --log-level <LEVEL> Set log level: debug|verbose|normal|silent");
|
||||
eprintln!(" --silent, -s Suppress info logs");
|
||||
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
|
||||
eprintln!(" --help, -h Show this help");
|
||||
std::process::exit(0);
|
||||
}
|
||||
s if !s.starts_with('-') => {
|
||||
config_path = s.to_string();
|
||||
}
|
||||
other => {
|
||||
eprintln!("Unknown option: {}", other);
|
||||
}
|
||||
s if !s.starts_with('-') => { config_path = s.to_string(); }
|
||||
other => { eprintln!("Unknown option: {}", other); }
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
@@ -76,20 +61,17 @@ fn parse_cli() -> (String, bool, Option<String>) {
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// 1. Parse CLI arguments
|
||||
let (config_path, cli_silent, cli_log_level) = parse_cli();
|
||||
|
||||
// 2. Load config (tracing not yet initialized — errors go to stderr)
|
||||
let config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
if std::path::Path::new(&config_path).exists() {
|
||||
eprintln!("[telemt] Error: Failed to load config '{}': {}", config_path, e);
|
||||
eprintln!("[telemt] Error: {}", e);
|
||||
std::process::exit(1);
|
||||
} else {
|
||||
let default = ProxyConfig::default();
|
||||
let toml_str = toml::to_string_pretty(&default).unwrap();
|
||||
std::fs::write(&config_path, toml_str).unwrap();
|
||||
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||
eprintln!("[telemt] Created default config at {}", config_path);
|
||||
default
|
||||
}
|
||||
@@ -97,80 +79,90 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
};
|
||||
|
||||
if let Err(e) = config.validate() {
|
||||
eprintln!("[telemt] Error: Invalid configuration: {}", e);
|
||||
eprintln!("[telemt] Invalid config: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// 3. Determine effective log level
|
||||
// Priority: RUST_LOG env > CLI flags > config file > default (normal)
|
||||
let effective_log_level = if cli_silent {
|
||||
LogLevel::Silent
|
||||
} else if let Some(ref level_str) = cli_log_level {
|
||||
LogLevel::from_str_loose(level_str)
|
||||
} else if let Some(ref s) = cli_log_level {
|
||||
LogLevel::from_str_loose(s)
|
||||
} else {
|
||||
config.general.log_level.clone()
|
||||
};
|
||||
|
||||
// 4. Initialize tracing
|
||||
let filter = if std::env::var("RUST_LOG").is_ok() {
|
||||
// RUST_LOG takes absolute priority
|
||||
EnvFilter::from_default_env()
|
||||
} else {
|
||||
EnvFilter::new(effective_log_level.to_filter_str())
|
||||
};
|
||||
|
||||
fmt()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
fmt().with_env_filter(filter).init();
|
||||
|
||||
// 5. Log startup info (operational — respects log level)
|
||||
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
|
||||
info!("Log level: {}", effective_log_level);
|
||||
info!(
|
||||
"Modes: classic={} secure={} tls={}",
|
||||
info!("Modes: classic={} secure={} tls={}",
|
||||
config.general.modes.classic,
|
||||
config.general.modes.secure,
|
||||
config.general.modes.tls
|
||||
);
|
||||
config.general.modes.tls);
|
||||
info!("TLS domain: {}", config.censorship.tls_domain);
|
||||
info!(
|
||||
"Mask: {} -> {}:{}",
|
||||
info!("Mask: {} -> {}:{}",
|
||||
config.censorship.mask,
|
||||
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
|
||||
config.censorship.mask_port
|
||||
);
|
||||
config.censorship.mask_port);
|
||||
|
||||
if config.censorship.tls_domain == "www.google.com" {
|
||||
warn!("Using default tls_domain (www.google.com). Consider setting a custom domain.");
|
||||
}
|
||||
|
||||
let prefer_ipv6 = config.general.prefer_ipv6;
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
|
||||
// Initialize ReplayChecker
|
||||
let replay_checker = Arc::new(ReplayChecker::new(
|
||||
config.access.replay_check_len,
|
||||
Duration::from_secs(config.access.replay_window_secs),
|
||||
));
|
||||
|
||||
// Initialize Upstream Manager
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
|
||||
|
||||
// Initialize Buffer Pool (16KB buffers, max 4096 cached ≈ 64MB)
|
||||
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
|
||||
|
||||
// Start health checks
|
||||
// === Startup DC Ping ===
|
||||
println!("=== Telegram DC Connectivity ===");
|
||||
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
|
||||
for upstream_result in &ping_results {
|
||||
println!(" via {}", upstream_result.upstream_name);
|
||||
for dc in &upstream_result.results {
|
||||
match (&dc.rtt_ms, &dc.error) {
|
||||
(Some(rtt), _) => {
|
||||
println!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
|
||||
}
|
||||
(None, Some(err)) => {
|
||||
println!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
|
||||
}
|
||||
(None, None) => {
|
||||
println!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("================================");
|
||||
|
||||
// Start background tasks
|
||||
let um_clone = upstream_manager.clone();
|
||||
tokio::spawn(async move {
|
||||
um_clone.run_health_checks().await;
|
||||
um_clone.run_health_checks(prefer_ipv6).await;
|
||||
});
|
||||
|
||||
let rc_clone = replay_checker.clone();
|
||||
tokio::spawn(async move {
|
||||
rc_clone.run_periodic_cleanup().await;
|
||||
});
|
||||
|
||||
// Detect public IP (once at startup)
|
||||
let detected_ip = detect_ip().await;
|
||||
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
|
||||
|
||||
// 6. Start listeners
|
||||
let mut listeners = Vec::new();
|
||||
|
||||
for listener_conf in &config.server.listeners {
|
||||
@@ -185,7 +177,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let listener = TcpListener::from_std(socket.into())?;
|
||||
info!("Listening on {}", addr);
|
||||
|
||||
// Determine public IP for tg:// links
|
||||
let public_ip = if let Some(ip) = listener_conf.announce_ip {
|
||||
ip
|
||||
} else if listener_conf.ip.is_unspecified() {
|
||||
@@ -198,30 +189,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
listener_conf.ip
|
||||
};
|
||||
|
||||
// 7. Print proxy links (always visible — uses println!, not tracing)
|
||||
if !config.show_link.is_empty() {
|
||||
println!("--- Proxy Links ({}) ---", public_ip);
|
||||
for user_name in &config.show_link {
|
||||
if let Some(secret) = config.access.users.get(user_name) {
|
||||
println!("[{}]", user_name);
|
||||
|
||||
if config.general.modes.classic {
|
||||
println!(" Classic: tg://proxy?server={}&port={}&secret={}",
|
||||
public_ip, config.server.port, secret);
|
||||
}
|
||||
|
||||
if config.general.modes.secure {
|
||||
println!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
|
||||
public_ip, config.server.port, secret);
|
||||
}
|
||||
|
||||
if config.general.modes.tls {
|
||||
let domain_hex = hex::encode(&config.censorship.tls_domain);
|
||||
println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
public_ip, config.server.port, secret, domain_hex);
|
||||
}
|
||||
} else {
|
||||
warn!("User '{}' in show_link not found in users", user_name);
|
||||
warn!("User '{}' in show_link not found", user_name);
|
||||
}
|
||||
}
|
||||
println!("------------------------");
|
||||
@@ -236,11 +223,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
}
|
||||
|
||||
if listeners.is_empty() {
|
||||
error!("No listeners could be started. Exiting.");
|
||||
error!("No listeners. Exiting.");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// 8. Accept loop
|
||||
for listener in listeners {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
@@ -262,14 +248,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = ClientHandler::new(
|
||||
stream,
|
||||
peer_addr,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng
|
||||
stream, peer_addr, config, stats,
|
||||
upstream_manager, replay_checker, buffer_pool, rng
|
||||
).run().await {
|
||||
debug!(peer = %peer_addr, error = %e, "Connection error");
|
||||
}
|
||||
@@ -284,7 +264,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
});
|
||||
}
|
||||
|
||||
// 9. Wait for shutdown signal
|
||||
match signal::ctrl_c().await {
|
||||
Ok(()) => info!("Shutting down..."),
|
||||
Err(e) => error!("Signal error: {}", e),
|
||||
|
||||
329
src/stats/mod.rs
329
src/stats/mod.rs
@@ -1,32 +1,28 @@
|
||||
//! Statistics
|
||||
//! Statistics and replay protection
|
||||
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Instant, Duration};
|
||||
use dashmap::DashMap;
|
||||
use parking_lot::{RwLock, Mutex};
|
||||
use parking_lot::Mutex;
|
||||
use lru::LruCache;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::collections::VecDeque;
|
||||
use tracing::debug;
|
||||
|
||||
// ============= Stats =============
|
||||
|
||||
/// Thread-safe statistics
|
||||
#[derive(Default)]
|
||||
pub struct Stats {
|
||||
// Global counters
|
||||
connects_all: AtomicU64,
|
||||
connects_bad: AtomicU64,
|
||||
handshake_timeouts: AtomicU64,
|
||||
|
||||
// Per-user stats
|
||||
user_stats: DashMap<String, UserStats>,
|
||||
|
||||
// Start time
|
||||
start_time: RwLock<Option<Instant>>,
|
||||
start_time: parking_lot::RwLock<Option<Instant>>,
|
||||
}
|
||||
|
||||
/// Per-user statistics
|
||||
#[derive(Default)]
|
||||
pub struct UserStats {
|
||||
pub connects: AtomicU64,
|
||||
@@ -44,42 +40,20 @@ impl Stats {
|
||||
stats
|
||||
}
|
||||
|
||||
// Global stats
|
||||
pub fn increment_connects_all(&self) {
|
||||
self.connects_all.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
|
||||
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
|
||||
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
|
||||
|
||||
pub fn increment_connects_bad(&self) {
|
||||
self.connects_bad.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_handshake_timeouts(&self) {
|
||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_connects_all(&self) -> u64 {
|
||||
self.connects_all.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_connects_bad(&self) -> u64 {
|
||||
self.connects_bad.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
// User stats
|
||||
pub fn increment_user_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_curr_connects(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.curr_connects
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||
@@ -89,47 +63,33 @@ impl Stats {
|
||||
}
|
||||
|
||||
pub fn get_user_curr_connects(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
self.user_stats.get(user)
|
||||
.map(|s| s.curr_connects.load(Ordering::Relaxed))
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_from_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.octets_to_client
|
||||
.fetch_add(bytes, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_from(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_from_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.msgs_from_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_user_msgs_to(&self, user: &str) {
|
||||
self.user_stats
|
||||
.entry(user.to_string())
|
||||
.or_default()
|
||||
.msgs_to_client
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
self.user_stats.entry(user.to_string()).or_default()
|
||||
.msgs_to_client.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_user_total_octets(&self, user: &str) -> u64 {
|
||||
self.user_stats
|
||||
.get(user)
|
||||
self.user_stats.get(user)
|
||||
.map(|s| {
|
||||
s.octets_from_client.load(Ordering::Relaxed) +
|
||||
s.octets_to_client.load(Ordering::Relaxed)
|
||||
@@ -144,21 +104,27 @@ impl Stats {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sharded Replay attack checker using LRU cache + sliding window
|
||||
/// Uses multiple independent LRU caches to reduce lock contention
|
||||
// ============= Replay Checker =============
|
||||
|
||||
pub struct ReplayChecker {
|
||||
shards: Vec<Mutex<ReplayShard>>,
|
||||
shard_mask: usize,
|
||||
window: Duration,
|
||||
checks: AtomicU64,
|
||||
hits: AtomicU64,
|
||||
additions: AtomicU64,
|
||||
cleanups: AtomicU64,
|
||||
}
|
||||
|
||||
struct ReplayEntry {
|
||||
seen_at: Instant,
|
||||
seq: u64,
|
||||
}
|
||||
|
||||
struct ReplayShard {
|
||||
cache: LruCache<Vec<u8>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Vec<u8>)>,
|
||||
cache: LruCache<Box<[u8]>, ReplayEntry>,
|
||||
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
|
||||
seq_counter: u64,
|
||||
}
|
||||
|
||||
impl ReplayShard {
|
||||
@@ -166,33 +132,60 @@ impl ReplayShard {
|
||||
Self {
|
||||
cache: LruCache::new(cap),
|
||||
queue: VecDeque::with_capacity(cap.get()),
|
||||
seq_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn next_seq(&mut self) -> u64 {
|
||||
self.seq_counter += 1;
|
||||
self.seq_counter
|
||||
}
|
||||
|
||||
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
return;
|
||||
}
|
||||
let cutoff = now - window;
|
||||
while let Some((ts, _)) = self.queue.front() {
|
||||
let cutoff = now.checked_sub(window).unwrap_or(now);
|
||||
|
||||
while let Some((ts, _, _)) = self.queue.front() {
|
||||
if *ts >= cutoff {
|
||||
break;
|
||||
}
|
||||
let (ts_old, key_old) = self.queue.pop_front().unwrap();
|
||||
if let Some(entry) = self.cache.get(&key_old) {
|
||||
if entry.seen_at <= ts_old {
|
||||
self.cache.pop(&key_old);
|
||||
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
|
||||
|
||||
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
|
||||
// between Borrow<[u8]> and Borrow<Box<[u8]>>
|
||||
if let Some(entry) = self.cache.peek(key.as_ref()) {
|
||||
if entry.seq == queue_seq {
|
||||
self.cache.pop(key.as_ref());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
|
||||
self.cleanup(now, window);
|
||||
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
|
||||
self.cache.get(key).is_some()
|
||||
}
|
||||
|
||||
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
|
||||
self.cleanup(now, window);
|
||||
|
||||
let seq = self.next_seq();
|
||||
let boxed_key: Box<[u8]> = key.into();
|
||||
|
||||
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
|
||||
self.queue.push_back((now, boxed_key, seq));
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.cache.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplayChecker {
|
||||
/// Create new replay checker with specified capacity per shard
|
||||
/// Total capacity = capacity * num_shards
|
||||
pub fn new(total_capacity: usize, window: Duration) -> Self {
|
||||
// Use 64 shards for good concurrency
|
||||
let num_shards = 64;
|
||||
let shard_capacity = (total_capacity / num_shards).max(1);
|
||||
let cap = NonZeroUsize::new(shard_capacity).unwrap();
|
||||
@@ -206,50 +199,114 @@ impl ReplayChecker {
|
||||
shards,
|
||||
shard_mask: num_shards - 1,
|
||||
window,
|
||||
checks: AtomicU64::new(0),
|
||||
hits: AtomicU64::new(0),
|
||||
additions: AtomicU64::new(0),
|
||||
cleanups: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_shard(&self, key: &[u8]) -> usize {
|
||||
fn get_shard_idx(&self, key: &[u8]) -> usize {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
key.hash(&mut hasher);
|
||||
(hasher.finish() as usize) & self.shard_mask
|
||||
}
|
||||
|
||||
fn check(&self, data: &[u8]) -> bool {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.get(&key).is_some()
|
||||
self.checks.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
let mut shard = self.shards[idx].lock();
|
||||
let found = shard.check(data, Instant::now(), self.window);
|
||||
if found {
|
||||
self.hits.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
found
|
||||
}
|
||||
|
||||
fn add(&self, data: &[u8]) {
|
||||
let shard_idx = self.get_shard(data);
|
||||
let mut shard = self.shards[shard_idx].lock();
|
||||
let now = Instant::now();
|
||||
shard.cleanup(now, self.window);
|
||||
|
||||
let key = data.to_vec();
|
||||
shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
|
||||
shard.queue.push_back((now, key));
|
||||
self.additions.fetch_add(1, Ordering::Relaxed);
|
||||
let idx = self.get_shard_idx(data);
|
||||
let mut shard = self.shards[idx].lock();
|
||||
shard.add(data, Instant::now(), self.window);
|
||||
}
|
||||
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool {
|
||||
self.check(data)
|
||||
pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
|
||||
pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
|
||||
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
|
||||
|
||||
pub fn stats(&self) -> ReplayStats {
|
||||
let mut total_entries = 0;
|
||||
let mut total_queue_len = 0;
|
||||
for shard in &self.shards {
|
||||
let s = shard.lock();
|
||||
total_entries += s.cache.len();
|
||||
total_queue_len += s.queue.len();
|
||||
}
|
||||
|
||||
ReplayStats {
|
||||
total_entries,
|
||||
total_queue_len,
|
||||
total_checks: self.checks.load(Ordering::Relaxed),
|
||||
total_hits: self.hits.load(Ordering::Relaxed),
|
||||
total_additions: self.additions.load(Ordering::Relaxed),
|
||||
total_cleanups: self.cleanups.load(Ordering::Relaxed),
|
||||
num_shards: self.shards.len(),
|
||||
window_secs: self.window.as_secs(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_handshake(&self, data: &[u8]) {
|
||||
self.add(data)
|
||||
pub async fn run_periodic_cleanup(&self) {
|
||||
let interval = if self.window.as_secs() > 60 {
|
||||
Duration::from_secs(30)
|
||||
} else {
|
||||
Duration::from_secs(self.window.as_secs().max(1) / 2)
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
|
||||
let now = Instant::now();
|
||||
let mut cleaned = 0usize;
|
||||
|
||||
for shard_mutex in &self.shards {
|
||||
let mut shard = shard_mutex.lock();
|
||||
let before = shard.len();
|
||||
shard.cleanup(now, self.window);
|
||||
let after = shard.len();
|
||||
cleaned += before.saturating_sub(after);
|
||||
}
|
||||
|
||||
self.cleanups.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if cleaned > 0 {
|
||||
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ReplayStats {
|
||||
pub total_entries: usize,
|
||||
pub total_queue_len: usize,
|
||||
pub total_checks: u64,
|
||||
pub total_hits: u64,
|
||||
pub total_additions: u64,
|
||||
pub total_cleanups: u64,
|
||||
pub num_shards: usize,
|
||||
pub window_secs: u64,
|
||||
}
|
||||
|
||||
impl ReplayStats {
|
||||
pub fn hit_rate(&self) -> f64 {
|
||||
if self.total_checks == 0 { 0.0 }
|
||||
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
|
||||
}
|
||||
|
||||
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
|
||||
self.check(data)
|
||||
}
|
||||
|
||||
pub fn add_tls_digest(&self, data: &[u8]) {
|
||||
self.add(data)
|
||||
pub fn ghost_ratio(&self) -> f64 {
|
||||
if self.total_entries == 0 { 0.0 }
|
||||
else { self.total_queue_len as f64 / self.total_entries as f64 }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,28 +317,60 @@ mod tests {
|
||||
#[test]
|
||||
fn test_stats_shared_counters() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let stats1 = Arc::clone(&stats);
|
||||
let stats2 = Arc::clone(&stats);
|
||||
|
||||
stats1.increment_connects_all();
|
||||
stats2.increment_connects_all();
|
||||
stats1.increment_connects_all();
|
||||
|
||||
stats.increment_connects_all();
|
||||
stats.increment_connects_all();
|
||||
stats.increment_connects_all();
|
||||
assert_eq!(stats.get_connects_all(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_sharding() {
|
||||
fn test_replay_checker_basic() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
let data1 = b"test1";
|
||||
let data2 = b"test2";
|
||||
assert!(!checker.check_handshake(b"test1"));
|
||||
checker.add_handshake(b"test1");
|
||||
assert!(checker.check_handshake(b"test1"));
|
||||
assert!(!checker.check_handshake(b"test2"));
|
||||
}
|
||||
|
||||
checker.add_handshake(data1);
|
||||
assert!(checker.check_handshake(data1));
|
||||
assert!(!checker.check_handshake(data2));
|
||||
#[test]
|
||||
fn test_replay_checker_duplicate_add() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
checker.add_handshake(b"dup");
|
||||
checker.add_handshake(b"dup");
|
||||
assert!(checker.check_handshake(b"dup"));
|
||||
}
|
||||
|
||||
checker.add_handshake(data2);
|
||||
assert!(checker.check_handshake(data2));
|
||||
#[test]
|
||||
fn test_replay_checker_expiration() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_millis(50));
|
||||
checker.add_handshake(b"expire");
|
||||
assert!(checker.check_handshake(b"expire"));
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
assert!(!checker.check_handshake(b"expire"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_stats() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
checker.add_handshake(b"k1");
|
||||
checker.add_handshake(b"k2");
|
||||
checker.check_handshake(b"k1");
|
||||
checker.check_handshake(b"k3");
|
||||
let stats = checker.stats();
|
||||
assert_eq!(stats.total_additions, 2);
|
||||
assert_eq!(stats.total_checks, 2);
|
||||
assert_eq!(stats.total_hits, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_many_keys() {
|
||||
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
|
||||
for i in 0..500u32 {
|
||||
checker.add(&i.to_le_bytes());
|
||||
}
|
||||
for i in 0..500u32 {
|
||||
assert!(checker.check(&i.to_le_bytes()));
|
||||
}
|
||||
assert_eq!(checker.stats().total_entries, 500);
|
||||
}
|
||||
}
|
||||
@@ -10,4 +10,4 @@ pub use pool::ConnectionPool;
|
||||
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
|
||||
pub use socket::*;
|
||||
pub use socks::*;
|
||||
pub use upstream::UpstreamManager;
|
||||
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};
|
||||
@@ -1,26 +1,78 @@
|
||||
//! Upstream Management
|
||||
//! Upstream Management with RTT tracking and startup ping
|
||||
|
||||
use std::net::{SocketAddr, IpAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::Instant;
|
||||
use rand::Rng;
|
||||
use tracing::{debug, warn, error, info};
|
||||
|
||||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::error::{Result, ProxyError};
|
||||
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
|
||||
use crate::transport::socket::create_outgoing_socket_bound;
|
||||
use crate::transport::socks::{connect_socks4, connect_socks5};
|
||||
|
||||
// ============= RTT Tracking =============
|
||||
|
||||
/// Exponential moving average for latency tracking
|
||||
#[derive(Debug, Clone)]
|
||||
struct LatencyEma {
|
||||
/// Current EMA value in milliseconds (None = no data yet)
|
||||
value_ms: Option<f64>,
|
||||
/// Smoothing factor (0.0 - 1.0, higher = more weight to recent)
|
||||
alpha: f64,
|
||||
}
|
||||
|
||||
impl LatencyEma {
|
||||
fn new(alpha: f64) -> Self {
|
||||
Self { value_ms: None, alpha }
|
||||
}
|
||||
|
||||
fn update(&mut self, sample_ms: f64) {
|
||||
self.value_ms = Some(match self.value_ms {
|
||||
None => sample_ms,
|
||||
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
|
||||
});
|
||||
}
|
||||
|
||||
fn get(&self) -> Option<f64> {
|
||||
self.value_ms
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Upstream State =============
|
||||
|
||||
#[derive(Debug)]
|
||||
struct UpstreamState {
|
||||
config: UpstreamConfig,
|
||||
healthy: bool,
|
||||
fails: u32,
|
||||
last_check: std::time::Instant,
|
||||
/// Latency EMA (alpha=0.3 — moderate smoothing)
|
||||
latency: LatencyEma,
|
||||
}
|
||||
|
||||
/// Result of a single DC ping
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DcPingResult {
|
||||
pub dc_idx: usize,
|
||||
pub dc_addr: SocketAddr,
|
||||
pub rtt_ms: Option<f64>,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of startup ping across all DCs
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StartupPingResult {
|
||||
pub results: Vec<DcPingResult>,
|
||||
pub upstream_name: String,
|
||||
}
|
||||
|
||||
// ============= Upstream Manager =============
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct UpstreamManager {
|
||||
upstreams: Arc<RwLock<Vec<UpstreamState>>>,
|
||||
@@ -35,6 +87,7 @@ impl UpstreamManager {
|
||||
healthy: true,
|
||||
fails: 0,
|
||||
last_check: std::time::Instant::now(),
|
||||
latency: LatencyEma::new(0.3),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -43,7 +96,7 @@ impl UpstreamManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Select an upstream using Weighted Round Robin (simplified)
|
||||
/// Select an upstream using weighted selection among healthy upstreams
|
||||
async fn select_upstream(&self) -> Option<usize> {
|
||||
let upstreams = self.upstreams.read().await;
|
||||
if upstreams.is_empty() {
|
||||
@@ -57,11 +110,9 @@ impl UpstreamManager {
|
||||
.collect();
|
||||
|
||||
if healthy_indices.is_empty() {
|
||||
// If all unhealthy, try any random one
|
||||
return Some(rand::rng().gen_range(0..upstreams.len()));
|
||||
}
|
||||
|
||||
// Weighted selection
|
||||
let total_weight: u32 = healthy_indices.iter()
|
||||
.map(|&i| upstreams[i].config.weight as u32)
|
||||
.sum();
|
||||
@@ -92,15 +143,19 @@ impl UpstreamManager {
|
||||
guard[idx].config.clone()
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
match self.connect_via_upstream(&upstream, target).await {
|
||||
Ok(stream) => {
|
||||
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
debug!(rtt_ms = rtt_ms, "Upstream recovered: {:?}", u.config);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
u.latency.update(rtt_ms);
|
||||
}
|
||||
Ok(stream)
|
||||
},
|
||||
@@ -108,10 +163,10 @@ impl UpstreamManager {
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(idx) {
|
||||
u.fails += 1;
|
||||
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
|
||||
warn!("Upstream {:?} failed: {}. Consecutive fails: {}", u.config, e, u.fails);
|
||||
if u.fails > 3 {
|
||||
u.healthy = false;
|
||||
warn!("Upstream disabled due to failures: {:?}", u.config);
|
||||
warn!("Upstream marked unhealthy: {:?}", u.config);
|
||||
}
|
||||
}
|
||||
Err(e)
|
||||
@@ -145,7 +200,7 @@ impl UpstreamManager {
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks4 { address, interface, user_id } => {
|
||||
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
|
||||
info!("Connecting to {} via SOCKS4 {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
|
||||
@@ -174,7 +229,7 @@ impl UpstreamManager {
|
||||
Ok(stream)
|
||||
},
|
||||
UpstreamType::Socks5 { address, interface, username, password } => {
|
||||
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
|
||||
info!("Connecting to {} via SOCKS5 {}", target, address);
|
||||
|
||||
let proxy_addr: SocketAddr = address.parse()
|
||||
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
|
||||
@@ -205,12 +260,109 @@ impl UpstreamManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Background task to check health
|
||||
pub async fn run_health_checks(&self) {
|
||||
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
|
||||
// ============= Startup Ping =============
|
||||
|
||||
/// Ping all Telegram DCs through all upstreams and return results.
|
||||
///
|
||||
/// Used at startup to display connectivity and latency info.
|
||||
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
|
||||
let upstreams: Vec<(usize, UpstreamConfig)> = {
|
||||
let guard = self.upstreams.read().await;
|
||||
guard.iter().enumerate()
|
||||
.map(|(i, u)| (i, u.config.clone()))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
|
||||
|
||||
let mut all_results = Vec::new();
|
||||
|
||||
for (upstream_idx, upstream_config) in &upstreams {
|
||||
let upstream_name = match &upstream_config.upstream_type {
|
||||
UpstreamType::Direct { interface } => {
|
||||
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
|
||||
}
|
||||
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
|
||||
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
|
||||
};
|
||||
|
||||
let mut dc_results = Vec::new();
|
||||
|
||||
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
|
||||
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
|
||||
|
||||
let ping_result = tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
self.ping_single_dc(upstream_config, dc_addr)
|
||||
).await;
|
||||
|
||||
let result = match ping_result {
|
||||
Ok(Ok(rtt_ms)) => {
|
||||
// Update latency EMA
|
||||
let mut guard = self.upstreams.write().await;
|
||||
if let Some(u) = guard.get_mut(*upstream_idx) {
|
||||
u.latency.update(rtt_ms);
|
||||
}
|
||||
DcPingResult {
|
||||
dc_idx: dc_zero_idx + 1,
|
||||
dc_addr,
|
||||
rtt_ms: Some(rtt_ms),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
Ok(Err(e)) => DcPingResult {
|
||||
dc_idx: dc_zero_idx + 1,
|
||||
dc_addr,
|
||||
rtt_ms: None,
|
||||
error: Some(e.to_string()),
|
||||
},
|
||||
Err(_) => DcPingResult {
|
||||
dc_idx: dc_zero_idx + 1,
|
||||
dc_addr,
|
||||
rtt_ms: None,
|
||||
error: Some("timeout (5s)".to_string()),
|
||||
},
|
||||
};
|
||||
|
||||
dc_results.push(result);
|
||||
}
|
||||
|
||||
all_results.push(StartupPingResult {
|
||||
results: dc_results,
|
||||
upstream_name,
|
||||
});
|
||||
}
|
||||
|
||||
all_results
|
||||
}
|
||||
|
||||
/// Ping a single DC: TCP connect, measure RTT, then drop.
|
||||
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
|
||||
let start = Instant::now();
|
||||
let _stream = self.connect_via_upstream(config, target).await?;
|
||||
let rtt = start.elapsed();
|
||||
Ok(rtt.as_secs_f64() * 1000.0)
|
||||
}
|
||||
|
||||
// ============= Health Checks =============
|
||||
|
||||
/// Background health check task.
|
||||
///
|
||||
/// Every 30 seconds, pings one representative DC per upstream.
|
||||
/// Measures RTT and updates health status.
|
||||
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
|
||||
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
|
||||
|
||||
// Rotate through DCs across check cycles
|
||||
let mut dc_rotation = 0usize;
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
|
||||
let check_dc_idx = dc_rotation % datacenters.len();
|
||||
dc_rotation += 1;
|
||||
|
||||
let check_target = SocketAddr::new(datacenters[check_dc_idx], TG_DATACENTER_PORT);
|
||||
|
||||
let count = self.upstreams.read().await.len();
|
||||
for i in 0..count {
|
||||
@@ -219,6 +371,7 @@ impl UpstreamManager {
|
||||
guard[i].config.clone()
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_secs(10),
|
||||
self.connect_via_upstream(&config, check_target)
|
||||
@@ -229,17 +382,42 @@ impl UpstreamManager {
|
||||
|
||||
match result {
|
||||
Ok(Ok(_stream)) => {
|
||||
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
|
||||
u.latency.update(rtt_ms);
|
||||
|
||||
if !u.healthy {
|
||||
debug!("Upstream recovered: {:?}", u.config);
|
||||
info!(
|
||||
rtt_ms = format!("{:.1}", rtt_ms),
|
||||
dc = check_dc_idx + 1,
|
||||
"Upstream recovered: {:?}", u.config
|
||||
);
|
||||
}
|
||||
u.healthy = true;
|
||||
u.fails = 0;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!("Health check failed for {:?}: {}", u.config, e);
|
||||
u.fails += 1;
|
||||
debug!(
|
||||
dc = check_dc_idx + 1,
|
||||
fails = u.fails,
|
||||
"Health check failed for {:?}: {}", u.config, e
|
||||
);
|
||||
if u.fails > 3 {
|
||||
u.healthy = false;
|
||||
warn!("Upstream unhealthy (health check): {:?}", u.config);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Health check timeout for {:?}", u.config);
|
||||
u.fails += 1;
|
||||
debug!(
|
||||
dc = check_dc_idx + 1,
|
||||
fails = u.fails,
|
||||
"Health check timeout for {:?}", u.config
|
||||
);
|
||||
if u.fails > 3 {
|
||||
u.healthy = false;
|
||||
warn!("Upstream unhealthy (timeout): {:?}", u.config);
|
||||
}
|
||||
}
|
||||
}
|
||||
u.last_check = std::time::Instant::now();
|
||||
|
||||
Reference in New Issue
Block a user