Antireplay Improvements + DC Ping
- Fix: `LruCache::get` type ambiguity in stats/mod.rs (a minimal standalone sketch follows below)
  - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already `&[u8]`, resolved via `Box<[u8]>: Borrow<[u8]>`)
  - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit `&[u8]` instead of `&Box<[u8]>`)
- Startup DC ping with RTT display, plus an improved health check (all DCs, RTT tracking, EMA latency, 30s interval):
  - Implemented `LatencyEma`: an exponential moving average (α=0.3) over RTT samples
  - `connect()` measures the RTT of each real connection and updates the EMA
  - `ping_all_dcs()` pings all 5 DCs via each upstream and returns a `Vec<StartupPingResult>` with RTT or error
  - `run_health_checks(prefer_ipv6)` accepts the IPv6 preference, rotates the probed DC between cycles (DC1→DC2→...→DC5→DC1...), runs every 30s instead of 60s, and now marks an upstream unhealthy after 3 consecutive failed checks
  - `DcPingResult` / `StartupPingResult`: public structures for display
  - DC ping at startup: calls `upstream_manager.ping_all_dcs()` before the accept loop and prints a table via `println!` (always visible)
  - Exported `StartupPingResult` and `DcPingResult` from the transport module
- Summary: startup DC ping with RTT, rotating health checks with EMA latency tracking, a 30-second interval, and correct unhealthy marking after 3 consecutive failures.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
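For context on the `Borrow` fix above, here is a small self-contained sketch (not part of the commit; the key contents are illustrative) of why passing a plain `&[u8]` resolves the lookup while `&key.into()` cannot be inferred:

```rust
use lru::LruCache;
use std::num::NonZeroUsize;

fn main() {
    // Keys are stored as Box<[u8]>, matching ReplayShard after this commit.
    let mut cache: LruCache<Box<[u8]>, u64> = LruCache::new(NonZeroUsize::new(8).unwrap());

    let key: Box<[u8]> = b"handshake-digest".as_slice().into();
    cache.put(key.clone(), 1);

    // get/peek/pop are generic over Q where Box<[u8]>: Borrow<Q>.
    // A &[u8] argument pins Q = [u8] via the Box<[u8]>: Borrow<[u8]> impl.
    let slice: &[u8] = b"handshake-digest";
    assert_eq!(cache.get(slice), Some(&1));
    assert_eq!(cache.peek(key.as_ref()), Some(&1)); // &Box<[u8]> -> &[u8]
    assert_eq!(cache.pop(key.as_ref()), Some(1));

    // The broken form was roughly `cache.get(&key.into())` with key: &[u8]:
    // the target type of `.into()` and the generic Q must each be inferred
    // from the other, so the compiler reports an ambiguity.
}
```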
src/main.rs (119 changed lines)
@@ -26,11 +26,6 @@ use crate::transport::{create_listener, ListenOptions, UpstreamManager};
 use crate::util::ip::detect_ip;
 use crate::stream::BufferPool;
 
-/// Parse command-line arguments.
-///
-/// Usage: telemt [config_path] [--silent] [--log-level <level>]
-///
-/// Returns (config_path, silent_flag, log_level_override)
 fn parse_cli() -> (String, bool, Option<String>) {
     let mut config_path = "config.toml".to_string();
     let mut silent = false;
@@ -40,33 +35,23 @@ fn parse_cli() -> (String, bool, Option<String>) {
     let mut i = 0;
     while i < args.len() {
         match args[i].as_str() {
-            "--silent" | "-s" => {
-                silent = true;
-            }
+            "--silent" | "-s" => { silent = true; }
             "--log-level" => {
                 i += 1;
-                if i < args.len() {
-                    log_level = Some(args[i].clone());
-                }
+                if i < args.len() { log_level = Some(args[i].clone()); }
             }
             s if s.starts_with("--log-level=") => {
                 log_level = Some(s.trim_start_matches("--log-level=").to_string());
             }
             "--help" | "-h" => {
                 eprintln!("Usage: telemt [config.toml] [OPTIONS]");
-                eprintln!();
-                eprintln!("Options:");
-                eprintln!("  --silent, -s         Suppress info logs (only warn/error)");
-                eprintln!("  --log-level <LEVEL>  Set log level: debug|verbose|normal|silent");
+                eprintln!("  --silent, -s         Suppress info logs");
+                eprintln!("  --log-level <LEVEL>  debug|verbose|normal|silent");
                 eprintln!("  --help, -h           Show this help");
                 std::process::exit(0);
             }
-            s if !s.starts_with('-') => {
-                config_path = s.to_string();
-            }
-            other => {
-                eprintln!("Unknown option: {}", other);
-            }
+            s if !s.starts_with('-') => { config_path = s.to_string(); }
+            other => { eprintln!("Unknown option: {}", other); }
         }
         i += 1;
     }
@@ -76,20 +61,17 @@ fn parse_cli() -> (String, bool, Option<String>) {
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    // 1. Parse CLI arguments
     let (config_path, cli_silent, cli_log_level) = parse_cli();
 
-    // 2. Load config (tracing not yet initialized — errors go to stderr)
     let config = match ProxyConfig::load(&config_path) {
         Ok(c) => c,
         Err(e) => {
             if std::path::Path::new(&config_path).exists() {
-                eprintln!("[telemt] Error: Failed to load config '{}': {}", config_path, e);
+                eprintln!("[telemt] Error: {}", e);
                 std::process::exit(1);
             } else {
                 let default = ProxyConfig::default();
-                let toml_str = toml::to_string_pretty(&default).unwrap();
-                std::fs::write(&config_path, toml_str).unwrap();
+                std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
                 eprintln!("[telemt] Created default config at {}", config_path);
                 default
             }
@@ -97,80 +79,90 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     };
 
     if let Err(e) = config.validate() {
-        eprintln!("[telemt] Error: Invalid configuration: {}", e);
+        eprintln!("[telemt] Invalid config: {}", e);
         std::process::exit(1);
     }
 
-    // 3. Determine effective log level
-    // Priority: RUST_LOG env > CLI flags > config file > default (normal)
     let effective_log_level = if cli_silent {
         LogLevel::Silent
-    } else if let Some(ref level_str) = cli_log_level {
-        LogLevel::from_str_loose(level_str)
+    } else if let Some(ref s) = cli_log_level {
+        LogLevel::from_str_loose(s)
     } else {
         config.general.log_level.clone()
     };
 
-    // 4. Initialize tracing
     let filter = if std::env::var("RUST_LOG").is_ok() {
-        // RUST_LOG takes absolute priority
         EnvFilter::from_default_env()
     } else {
         EnvFilter::new(effective_log_level.to_filter_str())
     };
 
-    fmt()
-        .with_env_filter(filter)
-        .init();
+    fmt().with_env_filter(filter).init();
 
-    // 5. Log startup info (operational — respects log level)
     info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
     info!("Log level: {}", effective_log_level);
-    info!(
-        "Modes: classic={} secure={} tls={}",
+    info!("Modes: classic={} secure={} tls={}",
         config.general.modes.classic,
         config.general.modes.secure,
-        config.general.modes.tls
-    );
+        config.general.modes.tls);
     info!("TLS domain: {}", config.censorship.tls_domain);
-    info!(
-        "Mask: {} -> {}:{}",
+    info!("Mask: {} -> {}:{}",
         config.censorship.mask,
         config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
-        config.censorship.mask_port
-    );
+        config.censorship.mask_port);
 
     if config.censorship.tls_domain == "www.google.com" {
         warn!("Using default tls_domain (www.google.com). Consider setting a custom domain.");
     }
 
+    let prefer_ipv6 = config.general.prefer_ipv6;
     let config = Arc::new(config);
     let stats = Arc::new(Stats::new());
     let rng = Arc::new(SecureRandom::new());
 
-    // Initialize ReplayChecker
     let replay_checker = Arc::new(ReplayChecker::new(
         config.access.replay_check_len,
         Duration::from_secs(config.access.replay_window_secs),
     ));
 
-    // Initialize Upstream Manager
     let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
 
-    // Initialize Buffer Pool (16KB buffers, max 4096 cached ≈ 64MB)
     let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
 
-    // Start health checks
+    // === Startup DC Ping ===
+    println!("=== Telegram DC Connectivity ===");
+    let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
+    for upstream_result in &ping_results {
+        println!("  via {}", upstream_result.upstream_name);
+        for dc in &upstream_result.results {
+            match (&dc.rtt_ms, &dc.error) {
+                (Some(rtt), _) => {
+                    println!("    DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
+                }
+                (None, Some(err)) => {
+                    println!("    DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
+                }
+                (None, None) => {
+                    println!("    DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
+                }
+            }
+        }
+    }
+    println!("================================");
+
+    // Start background tasks
     let um_clone = upstream_manager.clone();
     tokio::spawn(async move {
-        um_clone.run_health_checks().await;
+        um_clone.run_health_checks(prefer_ipv6).await;
+    });
+
+    let rc_clone = replay_checker.clone();
+    tokio::spawn(async move {
+        rc_clone.run_periodic_cleanup().await;
     });
 
-    // Detect public IP (once at startup)
     let detected_ip = detect_ip().await;
     debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
 
-    // 6. Start listeners
     let mut listeners = Vec::new();
 
     for listener_conf in &config.server.listeners {
@@ -185,7 +177,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         let listener = TcpListener::from_std(socket.into())?;
         info!("Listening on {}", addr);
 
-        // Determine public IP for tg:// links
         let public_ip = if let Some(ip) = listener_conf.announce_ip {
             ip
         } else if listener_conf.ip.is_unspecified() {
@@ -198,30 +189,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             listener_conf.ip
         };
 
-        // 7. Print proxy links (always visible — uses println!, not tracing)
         if !config.show_link.is_empty() {
             println!("--- Proxy Links ({}) ---", public_ip);
             for user_name in &config.show_link {
                 if let Some(secret) = config.access.users.get(user_name) {
                     println!("[{}]", user_name);
-
                     if config.general.modes.classic {
                         println!("  Classic: tg://proxy?server={}&port={}&secret={}",
                             public_ip, config.server.port, secret);
                     }
-
                     if config.general.modes.secure {
                         println!("  DD: tg://proxy?server={}&port={}&secret=dd{}",
                             public_ip, config.server.port, secret);
                     }
-
                     if config.general.modes.tls {
                         let domain_hex = hex::encode(&config.censorship.tls_domain);
                         println!("  EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
                             public_ip, config.server.port, secret, domain_hex);
                     }
                 } else {
-                    warn!("User '{}' in show_link not found in users", user_name);
+                    warn!("User '{}' in show_link not found", user_name);
                 }
             }
             println!("------------------------");
@@ -236,11 +223,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     }
 
     if listeners.is_empty() {
-        error!("No listeners could be started. Exiting.");
+        error!("No listeners. Exiting.");
        std::process::exit(1);
     }
 
-    // 8. Accept loop
     for listener in listeners {
         let config = config.clone();
         let stats = stats.clone();
@@ -262,14 +248,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 
             tokio::spawn(async move {
                 if let Err(e) = ClientHandler::new(
-                    stream,
-                    peer_addr,
-                    config,
-                    stats,
-                    upstream_manager,
-                    replay_checker,
-                    buffer_pool,
-                    rng
+                    stream, peer_addr, config, stats,
+                    upstream_manager, replay_checker, buffer_pool, rng
                 ).run().await {
                     debug!(peer = %peer_addr, error = %e, "Connection error");
                 }
@@ -284,7 +264,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         });
     }
 
-    // 9. Wait for shutdown signal
     match signal::ctrl_c().await {
         Ok(()) => info!("Shutting down..."),
         Err(e) => error!("Signal error: {}", e),
src/stats/mod.rs (323 changed lines)
@@ -1,32 +1,28 @@
-//! Statistics
+//! Statistics and replay protection
 
 use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use std::time::{Instant, Duration};
 use dashmap::DashMap;
-use parking_lot::{RwLock, Mutex};
+use parking_lot::Mutex;
 use lru::LruCache;
 use std::num::NonZeroUsize;
 use std::hash::{Hash, Hasher};
 use std::collections::hash_map::DefaultHasher;
 use std::collections::VecDeque;
+use tracing::debug;
 
+// ============= Stats =============
 
-/// Thread-safe statistics
 #[derive(Default)]
 pub struct Stats {
-    // Global counters
     connects_all: AtomicU64,
     connects_bad: AtomicU64,
     handshake_timeouts: AtomicU64,
 
-    // Per-user stats
     user_stats: DashMap<String, UserStats>,
-
-    // Start time
-    start_time: RwLock<Option<Instant>>,
+    start_time: parking_lot::RwLock<Option<Instant>>,
 }
 
-/// Per-user statistics
 #[derive(Default)]
 pub struct UserStats {
     pub connects: AtomicU64,
@@ -44,42 +40,20 @@ impl Stats {
         stats
     }
 
-    // Global stats
-    pub fn increment_connects_all(&self) {
-        self.connects_all.fetch_add(1, Ordering::Relaxed);
-    }
-
-    pub fn increment_connects_bad(&self) {
-        self.connects_bad.fetch_add(1, Ordering::Relaxed);
-    }
-
-    pub fn increment_handshake_timeouts(&self) {
-        self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
-    }
-
-    pub fn get_connects_all(&self) -> u64 {
-        self.connects_all.load(Ordering::Relaxed)
-    }
-
-    pub fn get_connects_bad(&self) -> u64 {
-        self.connects_bad.load(Ordering::Relaxed)
-    }
-
-    // User stats
+    pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
+    pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
+    pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
+    pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
+    pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
 
     pub fn increment_user_connects(&self, user: &str) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .connects
-            .fetch_add(1, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .connects.fetch_add(1, Ordering::Relaxed);
     }
 
     pub fn increment_user_curr_connects(&self, user: &str) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .curr_connects
-            .fetch_add(1, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .curr_connects.fetch_add(1, Ordering::Relaxed);
     }
 
     pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -89,47 +63,33 @@ impl Stats {
     }
 
     pub fn get_user_curr_connects(&self, user: &str) -> u64 {
-        self.user_stats
-            .get(user)
+        self.user_stats.get(user)
             .map(|s| s.curr_connects.load(Ordering::Relaxed))
             .unwrap_or(0)
     }
 
     pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .octets_from_client
-            .fetch_add(bytes, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .octets_from_client.fetch_add(bytes, Ordering::Relaxed);
     }
 
     pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .octets_to_client
-            .fetch_add(bytes, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .octets_to_client.fetch_add(bytes, Ordering::Relaxed);
     }
 
     pub fn increment_user_msgs_from(&self, user: &str) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .msgs_from_client
-            .fetch_add(1, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .msgs_from_client.fetch_add(1, Ordering::Relaxed);
     }
 
     pub fn increment_user_msgs_to(&self, user: &str) {
-        self.user_stats
-            .entry(user.to_string())
-            .or_default()
-            .msgs_to_client
-            .fetch_add(1, Ordering::Relaxed);
+        self.user_stats.entry(user.to_string()).or_default()
+            .msgs_to_client.fetch_add(1, Ordering::Relaxed);
     }
 
     pub fn get_user_total_octets(&self, user: &str) -> u64 {
-        self.user_stats
-            .get(user)
+        self.user_stats.get(user)
         .map(|s| {
             s.octets_from_client.load(Ordering::Relaxed) +
             s.octets_to_client.load(Ordering::Relaxed)
@@ -144,21 +104,27 @@ impl Stats {
     }
 }
 
-/// Sharded Replay attack checker using LRU cache + sliding window
-/// Uses multiple independent LRU caches to reduce lock contention
+// ============= Replay Checker =============
+
 pub struct ReplayChecker {
     shards: Vec<Mutex<ReplayShard>>,
     shard_mask: usize,
     window: Duration,
+    checks: AtomicU64,
+    hits: AtomicU64,
+    additions: AtomicU64,
+    cleanups: AtomicU64,
 }
 
 struct ReplayEntry {
     seen_at: Instant,
+    seq: u64,
 }
 
 struct ReplayShard {
-    cache: LruCache<Vec<u8>, ReplayEntry>,
-    queue: VecDeque<(Instant, Vec<u8>)>,
+    cache: LruCache<Box<[u8]>, ReplayEntry>,
+    queue: VecDeque<(Instant, Box<[u8]>, u64)>,
+    seq_counter: u64,
 }
 
 impl ReplayShard {
@@ -166,33 +132,60 @@ impl ReplayShard {
         Self {
             cache: LruCache::new(cap),
             queue: VecDeque::with_capacity(cap.get()),
+            seq_counter: 0,
         }
     }
 
+    fn next_seq(&mut self) -> u64 {
+        self.seq_counter += 1;
+        self.seq_counter
+    }
+
     fn cleanup(&mut self, now: Instant, window: Duration) {
         if window.is_zero() {
             return;
         }
-        let cutoff = now - window;
-        while let Some((ts, _)) = self.queue.front() {
+        let cutoff = now.checked_sub(window).unwrap_or(now);
+        while let Some((ts, _, _)) = self.queue.front() {
             if *ts >= cutoff {
                 break;
             }
-            let (ts_old, key_old) = self.queue.pop_front().unwrap();
-            if let Some(entry) = self.cache.get(&key_old) {
-                if entry.seen_at <= ts_old {
-                    self.cache.pop(&key_old);
-                }
+            let (_, key, queue_seq) = self.queue.pop_front().unwrap();
+            // Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
+            // between Borrow<[u8]> and Borrow<Box<[u8]>>
+            if let Some(entry) = self.cache.peek(key.as_ref()) {
+                if entry.seq == queue_seq {
+                    self.cache.pop(key.as_ref());
+                }
             }
         }
     }
+
+    fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
+        self.cleanup(now, window);
+        // key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
+        self.cache.get(key).is_some()
+    }
+
+    fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
+        self.cleanup(now, window);
+
+        let seq = self.next_seq();
+        let boxed_key: Box<[u8]> = key.into();
+
+        self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
+        self.queue.push_back((now, boxed_key, seq));
+    }
+
+    fn len(&self) -> usize {
+        self.cache.len()
+    }
 }
 
 impl ReplayChecker {
-    /// Create new replay checker with specified capacity per shard
-    /// Total capacity = capacity * num_shards
     pub fn new(total_capacity: usize, window: Duration) -> Self {
-        // Use 64 shards for good concurrency
         let num_shards = 64;
         let shard_capacity = (total_capacity / num_shards).max(1);
         let cap = NonZeroUsize::new(shard_capacity).unwrap();
@@ -206,50 +199,114 @@ impl ReplayChecker {
             shards,
             shard_mask: num_shards - 1,
             window,
+            checks: AtomicU64::new(0),
+            hits: AtomicU64::new(0),
+            additions: AtomicU64::new(0),
+            cleanups: AtomicU64::new(0),
         }
     }
 
-    fn get_shard(&self, key: &[u8]) -> usize {
+    fn get_shard_idx(&self, key: &[u8]) -> usize {
         let mut hasher = DefaultHasher::new();
         key.hash(&mut hasher);
         (hasher.finish() as usize) & self.shard_mask
     }
 
     fn check(&self, data: &[u8]) -> bool {
-        let shard_idx = self.get_shard(data);
-        let mut shard = self.shards[shard_idx].lock();
-        let now = Instant::now();
-        shard.cleanup(now, self.window);
-        let key = data.to_vec();
-        shard.cache.get(&key).is_some()
+        self.checks.fetch_add(1, Ordering::Relaxed);
+        let idx = self.get_shard_idx(data);
+        let mut shard = self.shards[idx].lock();
+        let found = shard.check(data, Instant::now(), self.window);
+        if found {
+            self.hits.fetch_add(1, Ordering::Relaxed);
+        }
+        found
     }
 
     fn add(&self, data: &[u8]) {
-        let shard_idx = self.get_shard(data);
-        let mut shard = self.shards[shard_idx].lock();
-        let now = Instant::now();
-        shard.cleanup(now, self.window);
-        let key = data.to_vec();
-        shard.cache.put(key.clone(), ReplayEntry { seen_at: now });
-        shard.queue.push_back((now, key));
+        self.additions.fetch_add(1, Ordering::Relaxed);
+        let idx = self.get_shard_idx(data);
+        let mut shard = self.shards[idx].lock();
+        shard.add(data, Instant::now(), self.window);
     }
 
-    pub fn check_handshake(&self, data: &[u8]) -> bool {
-        self.check(data)
-    }
-
-    pub fn add_handshake(&self, data: &[u8]) {
-        self.add(data)
-    }
-
-    pub fn check_tls_digest(&self, data: &[u8]) -> bool {
-        self.check(data)
-    }
-
-    pub fn add_tls_digest(&self, data: &[u8]) {
-        self.add(data)
-    }
+    pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
+    pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
+    pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
+    pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
+
+    pub fn stats(&self) -> ReplayStats {
+        let mut total_entries = 0;
+        let mut total_queue_len = 0;
+        for shard in &self.shards {
+            let s = shard.lock();
+            total_entries += s.cache.len();
+            total_queue_len += s.queue.len();
+        }
+
+        ReplayStats {
+            total_entries,
+            total_queue_len,
+            total_checks: self.checks.load(Ordering::Relaxed),
+            total_hits: self.hits.load(Ordering::Relaxed),
+            total_additions: self.additions.load(Ordering::Relaxed),
+            total_cleanups: self.cleanups.load(Ordering::Relaxed),
+            num_shards: self.shards.len(),
+            window_secs: self.window.as_secs(),
+        }
+    }
+
+    pub async fn run_periodic_cleanup(&self) {
+        let interval = if self.window.as_secs() > 60 {
+            Duration::from_secs(30)
+        } else {
+            Duration::from_secs(self.window.as_secs().max(1) / 2)
+        };
+
+        loop {
+            tokio::time::sleep(interval).await;
+
+            let now = Instant::now();
+            let mut cleaned = 0usize;
+
+            for shard_mutex in &self.shards {
+                let mut shard = shard_mutex.lock();
+                let before = shard.len();
+                shard.cleanup(now, self.window);
+                let after = shard.len();
+                cleaned += before.saturating_sub(after);
+            }
+
+            self.cleanups.fetch_add(1, Ordering::Relaxed);
+
+            if cleaned > 0 {
+                debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
+            }
+        }
+    }
 }
+
+#[derive(Debug, Clone)]
+pub struct ReplayStats {
+    pub total_entries: usize,
+    pub total_queue_len: usize,
+    pub total_checks: u64,
+    pub total_hits: u64,
+    pub total_additions: u64,
+    pub total_cleanups: u64,
+    pub num_shards: usize,
+    pub window_secs: u64,
+}
+
+impl ReplayStats {
+    pub fn hit_rate(&self) -> f64 {
+        if self.total_checks == 0 { 0.0 }
+        else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
+    }
+
+    pub fn ghost_ratio(&self) -> f64 {
+        if self.total_entries == 0 { 0.0 }
+        else { self.total_queue_len as f64 / self.total_entries as f64 }
+    }
+}
@@ -260,28 +317,60 @@ mod tests {
     #[test]
     fn test_stats_shared_counters() {
         let stats = Arc::new(Stats::new());
-        let stats1 = Arc::clone(&stats);
-        let stats2 = Arc::clone(&stats);
-
-        stats1.increment_connects_all();
-        stats2.increment_connects_all();
-        stats1.increment_connects_all();
+        stats.increment_connects_all();
+        stats.increment_connects_all();
+        stats.increment_connects_all();
 
         assert_eq!(stats.get_connects_all(), 3);
     }
 
     #[test]
-    fn test_replay_checker_sharding() {
+    fn test_replay_checker_basic() {
         let checker = ReplayChecker::new(100, Duration::from_secs(60));
-        let data1 = b"test1";
-        let data2 = b"test2";
-
-        checker.add_handshake(data1);
-        assert!(checker.check_handshake(data1));
-        assert!(!checker.check_handshake(data2));
-
-        checker.add_handshake(data2);
-        assert!(checker.check_handshake(data2));
+        assert!(!checker.check_handshake(b"test1"));
+        checker.add_handshake(b"test1");
+        assert!(checker.check_handshake(b"test1"));
+        assert!(!checker.check_handshake(b"test2"));
+    }
+
+    #[test]
+    fn test_replay_checker_duplicate_add() {
+        let checker = ReplayChecker::new(100, Duration::from_secs(60));
+        checker.add_handshake(b"dup");
+        checker.add_handshake(b"dup");
+        assert!(checker.check_handshake(b"dup"));
+    }
+
+    #[test]
+    fn test_replay_checker_expiration() {
+        let checker = ReplayChecker::new(100, Duration::from_millis(50));
+        checker.add_handshake(b"expire");
+        assert!(checker.check_handshake(b"expire"));
+        std::thread::sleep(Duration::from_millis(100));
+        assert!(!checker.check_handshake(b"expire"));
+    }
+
+    #[test]
+    fn test_replay_checker_stats() {
+        let checker = ReplayChecker::new(100, Duration::from_secs(60));
+        checker.add_handshake(b"k1");
+        checker.add_handshake(b"k2");
+        checker.check_handshake(b"k1");
+        checker.check_handshake(b"k3");
+        let stats = checker.stats();
+        assert_eq!(stats.total_additions, 2);
+        assert_eq!(stats.total_checks, 2);
+        assert_eq!(stats.total_hits, 1);
+    }
+
+    #[test]
+    fn test_replay_checker_many_keys() {
+        let checker = ReplayChecker::new(1000, Duration::from_secs(60));
+        for i in 0..500u32 {
+            checker.add(&i.to_le_bytes());
+        }
+        for i in 0..500u32 {
+            assert!(checker.check(&i.to_le_bytes()));
+        }
+        assert_eq!(checker.stats().total_entries, 500);
     }
 }
src/transport/mod.rs

@@ -10,4 +10,4 @@ pub use pool::ConnectionPool;
 pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
 pub use socket::*;
 pub use socks::*;
-pub use upstream::UpstreamManager;
+pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};
src/transport/upstream.rs

@@ -1,26 +1,78 @@
-//! Upstream Management
+//! Upstream Management with RTT tracking and startup ping
 
 use std::net::{SocketAddr, IpAddr};
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::net::TcpStream;
 use tokio::sync::RwLock;
+use tokio::time::Instant;
 use rand::Rng;
 use tracing::{debug, warn, error, info};
 
 use crate::config::{UpstreamConfig, UpstreamType};
 use crate::error::{Result, ProxyError};
+use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
 use crate::transport::socket::create_outgoing_socket_bound;
 use crate::transport::socks::{connect_socks4, connect_socks5};
 
+// ============= RTT Tracking =============
+
+/// Exponential moving average for latency tracking
+#[derive(Debug, Clone)]
+struct LatencyEma {
+    /// Current EMA value in milliseconds (None = no data yet)
+    value_ms: Option<f64>,
+    /// Smoothing factor (0.0 - 1.0, higher = more weight to recent)
+    alpha: f64,
+}
+
+impl LatencyEma {
+    fn new(alpha: f64) -> Self {
+        Self { value_ms: None, alpha }
+    }
+
+    fn update(&mut self, sample_ms: f64) {
+        self.value_ms = Some(match self.value_ms {
+            None => sample_ms,
+            Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
+        });
+    }
+
+    fn get(&self) -> Option<f64> {
+        self.value_ms
+    }
+}
+
+// ============= Upstream State =============
+
 #[derive(Debug)]
 struct UpstreamState {
     config: UpstreamConfig,
     healthy: bool,
     fails: u32,
     last_check: std::time::Instant,
+    /// Latency EMA (alpha=0.3 — moderate smoothing)
+    latency: LatencyEma,
 }
 
+/// Result of a single DC ping
+#[derive(Debug, Clone)]
+pub struct DcPingResult {
+    pub dc_idx: usize,
+    pub dc_addr: SocketAddr,
+    pub rtt_ms: Option<f64>,
+    pub error: Option<String>,
+}
+
+/// Result of startup ping across all DCs
+#[derive(Debug, Clone)]
+pub struct StartupPingResult {
+    pub results: Vec<DcPingResult>,
+    pub upstream_name: String,
+}
+
+// ============= Upstream Manager =============
+
 #[derive(Clone)]
 pub struct UpstreamManager {
     upstreams: Arc<RwLock<Vec<UpstreamState>>>,
@@ -35,6 +87,7 @@ impl UpstreamManager {
                 healthy: true,
                 fails: 0,
                 last_check: std::time::Instant::now(),
+                latency: LatencyEma::new(0.3),
             })
             .collect();
 
@@ -43,7 +96,7 @@ impl UpstreamManager {
         }
     }
 
-    /// Select an upstream using Weighted Round Robin (simplified)
+    /// Select an upstream using weighted selection among healthy upstreams
     async fn select_upstream(&self) -> Option<usize> {
         let upstreams = self.upstreams.read().await;
         if upstreams.is_empty() {
@@ -57,11 +110,9 @@ impl UpstreamManager {
             .collect();
 
         if healthy_indices.is_empty() {
-            // If all unhealthy, try any random one
             return Some(rand::rng().gen_range(0..upstreams.len()));
         }
 
-        // Weighted selection
         let total_weight: u32 = healthy_indices.iter()
             .map(|&i| upstreams[i].config.weight as u32)
             .sum();
@@ -92,15 +143,19 @@ impl UpstreamManager {
             guard[idx].config.clone()
         };
 
+        let start = Instant::now();
+
         match self.connect_via_upstream(&upstream, target).await {
             Ok(stream) => {
+                let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
                 let mut guard = self.upstreams.write().await;
                 if let Some(u) = guard.get_mut(idx) {
                     if !u.healthy {
-                        debug!("Upstream recovered: {:?}", u.config);
+                        debug!(rtt_ms = rtt_ms, "Upstream recovered: {:?}", u.config);
                     }
                     u.healthy = true;
                     u.fails = 0;
+                    u.latency.update(rtt_ms);
                 }
                 Ok(stream)
             },
@@ -108,10 +163,10 @@ impl UpstreamManager {
                 let mut guard = self.upstreams.write().await;
                 if let Some(u) = guard.get_mut(idx) {
                     u.fails += 1;
-                    warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails);
+                    warn!("Upstream {:?} failed: {}. Consecutive fails: {}", u.config, e, u.fails);
                     if u.fails > 3 {
                         u.healthy = false;
-                        warn!("Upstream disabled due to failures: {:?}", u.config);
+                        warn!("Upstream marked unhealthy: {:?}", u.config);
                     }
                 }
                 Err(e)
@@ -145,7 +200,7 @@ impl UpstreamManager {
                 Ok(stream)
             },
             UpstreamType::Socks4 { address, interface, user_id } => {
-                info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
+                info!("Connecting to {} via SOCKS4 {}", target, address);
 
                 let proxy_addr: SocketAddr = address.parse()
                     .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
@@ -174,7 +229,7 @@ impl UpstreamManager {
                 Ok(stream)
             },
             UpstreamType::Socks5 { address, interface, username, password } => {
-                info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
+                info!("Connecting to {} via SOCKS5 {}", target, address);
 
                 let proxy_addr: SocketAddr = address.parse()
                     .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
@@ -205,12 +260,109 @@ impl UpstreamManager {
         }
     }
 
-    /// Background task to check health
-    pub async fn run_health_checks(&self) {
-        let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap();
+    // ============= Startup Ping =============
+
+    /// Ping all Telegram DCs through all upstreams and return results.
+    ///
+    /// Used at startup to display connectivity and latency info.
+    pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
+        let upstreams: Vec<(usize, UpstreamConfig)> = {
+            let guard = self.upstreams.read().await;
+            guard.iter().enumerate()
+                .map(|(i, u)| (i, u.config.clone()))
+                .collect()
+        };
+
+        let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
+
+        let mut all_results = Vec::new();
+
+        for (upstream_idx, upstream_config) in &upstreams {
+            let upstream_name = match &upstream_config.upstream_type {
+                UpstreamType::Direct { interface } => {
+                    format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
+                }
+                UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
+                UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
+            };
+
+            let mut dc_results = Vec::new();
+
+            for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
+                let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
+
+                let ping_result = tokio::time::timeout(
+                    Duration::from_secs(5),
+                    self.ping_single_dc(upstream_config, dc_addr)
+                ).await;
+
+                let result = match ping_result {
+                    Ok(Ok(rtt_ms)) => {
+                        // Update latency EMA
+                        let mut guard = self.upstreams.write().await;
+                        if let Some(u) = guard.get_mut(*upstream_idx) {
+                            u.latency.update(rtt_ms);
+                        }
+                        DcPingResult {
+                            dc_idx: dc_zero_idx + 1,
+                            dc_addr,
+                            rtt_ms: Some(rtt_ms),
+                            error: None,
+                        }
+                    }
+                    Ok(Err(e)) => DcPingResult {
+                        dc_idx: dc_zero_idx + 1,
+                        dc_addr,
+                        rtt_ms: None,
+                        error: Some(e.to_string()),
+                    },
+                    Err(_) => DcPingResult {
+                        dc_idx: dc_zero_idx + 1,
+                        dc_addr,
+                        rtt_ms: None,
+                        error: Some("timeout (5s)".to_string()),
+                    },
+                };
+
+                dc_results.push(result);
+            }
+
+            all_results.push(StartupPingResult {
+                results: dc_results,
+                upstream_name,
+            });
+        }
+
+        all_results
+    }
+
+    /// Ping a single DC: TCP connect, measure RTT, then drop.
+    async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
+        let start = Instant::now();
+        let _stream = self.connect_via_upstream(config, target).await?;
+        let rtt = start.elapsed();
+        Ok(rtt.as_secs_f64() * 1000.0)
+    }
+
+    // ============= Health Checks =============
+
+    /// Background health check task.
+    ///
+    /// Every 30 seconds, pings one representative DC per upstream.
+    /// Measures RTT and updates health status.
+    pub async fn run_health_checks(&self, prefer_ipv6: bool) {
+        let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
+
+        // Rotate through DCs across check cycles
+        let mut dc_rotation = 0usize;
+
         loop {
-            tokio::time::sleep(Duration::from_secs(60)).await;
+            tokio::time::sleep(Duration::from_secs(30)).await;
+
+            let check_dc_idx = dc_rotation % datacenters.len();
+            dc_rotation += 1;
+
+            let check_target = SocketAddr::new(datacenters[check_dc_idx], TG_DATACENTER_PORT);
+
             let count = self.upstreams.read().await.len();
             for i in 0..count {
@@ -219,6 +371,7 @@ impl UpstreamManager {
                     guard[i].config.clone()
                 };
 
+                let start = Instant::now();
                 let result = tokio::time::timeout(
                     Duration::from_secs(10),
                     self.connect_via_upstream(&config, check_target)
@@ -229,17 +382,42 @@ impl UpstreamManager {
 
                 match result {
                     Ok(Ok(_stream)) => {
+                        let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
+                        u.latency.update(rtt_ms);
+
                         if !u.healthy {
-                            debug!("Upstream recovered: {:?}", u.config);
+                            info!(
+                                rtt_ms = format!("{:.1}", rtt_ms),
+                                dc = check_dc_idx + 1,
+                                "Upstream recovered: {:?}", u.config
+                            );
                         }
                         u.healthy = true;
                         u.fails = 0;
                     }
                     Ok(Err(e)) => {
-                        debug!("Health check failed for {:?}: {}", u.config, e);
+                        u.fails += 1;
+                        debug!(
+                            dc = check_dc_idx + 1,
+                            fails = u.fails,
+                            "Health check failed for {:?}: {}", u.config, e
+                        );
+                        if u.fails > 3 {
+                            u.healthy = false;
+                            warn!("Upstream unhealthy (health check): {:?}", u.config);
+                        }
                     }
                     Err(_) => {
-                        debug!("Health check timeout for {:?}", u.config);
+                        u.fails += 1;
+                        debug!(
+                            dc = check_dc_idx + 1,
+                            fails = u.fails,
+                            "Health check timeout for {:?}", u.config
+                        );
+                        if u.fails > 3 {
+                            u.healthy = false;
+                            warn!("Upstream unhealthy (timeout): {:?}", u.config);
+                        }
                     }
                 }
                 u.last_check = std::time::Instant::now();
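To close, a standalone sketch of the EMA arithmetic that `LatencyEma` (α = 0.3) applies to RTT samples; the struct mirrors the one added in this commit, and the sample values are illustrative:

```rust
#[derive(Debug, Clone)]
struct LatencyEma {
    value_ms: Option<f64>, // None until the first sample arrives
    alpha: f64,            // weight given to the newest sample
}

impl LatencyEma {
    fn new(alpha: f64) -> Self {
        Self { value_ms: None, alpha }
    }

    fn update(&mut self, sample_ms: f64) {
        self.value_ms = Some(match self.value_ms {
            None => sample_ms, // first sample seeds the average
            Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
        });
    }
}

fn main() {
    let mut ema = LatencyEma::new(0.3);
    for sample in [100.0, 100.0, 40.0] {
        ema.update(sample);
    }
    // 100 -> 100 -> 0.7 * 100 + 0.3 * 40 = 82: one fast probe moves the
    // estimate by only 30% of the difference, smoothing out jitter.
    assert!((ema.value_ms.unwrap() - 82.0).abs() < 1e-9);
}
```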