IP Version Superfallback

Alexey
2026-02-14 00:26:07 +03:00
parent de28655dd2
commit 9b850b0bfb
4 changed files with 775 additions and 581 deletions

View File

@@ -1,391 +1,434 @@
//! telemt — Telegram MTProto Proxy
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::sync::Semaphore;
use tracing::{info, error, warn, debug};
use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*};
mod cli;
mod config;
mod crypto;
mod error;
mod protocol;
mod proxy;
mod stats;
mod stream;
mod transport;
mod util;
use crate::config::{ProxyConfig, LogLevel};
use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker};
use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::transport::middle_proxy::MePool;
use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut silent = false;
let mut log_level: Option<String> = None;
let args: Vec<String> = std::env::args().skip(1).collect();
// Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
if let Err(e) = cli::run_init(init_opts) {
eprintln!("[telemt] Init failed: {}", e);
std::process::exit(1);
}
std::process::exit(0);
}
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => { silent = true; }
"--log-level" => {
i += 1;
if i < args.len() { log_level = Some(args[i].clone()); }
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start");
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
}
i += 1;
}
(config_path, silent, log_level)
}
#[tokio::main] let mut i = 0;
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> { while i < args.len() {
let (config_path, cli_silent, cli_log_level) = parse_cli(); match args[i].as_str() {
"--silent" | "-s" => { silent = true; }
let config = match ProxyConfig::load(&config_path) { "--log-level" => {
Ok(c) => c, i += 1;
Err(e) => { if i < args.len() { log_level = Some(args[i].clone()); }
if std::path::Path::new(&config_path).exists() {
eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
} else {
let default = ProxyConfig::default();
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
default
}
} }
}; s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
if let Err(e) = config.validate() {
eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
let has_rust_log = std::env::var("RUST_LOG").is_ok();
let effective_log_level = if cli_silent {
LogLevel::Silent
} else if let Some(ref s) = cli_log_level {
LogLevel::from_str_loose(s)
} else {
config.general.log_level.clone()
};
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt::Layer::default())
.init();
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
info!("Modes: classic={} secure={} tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls);
info!("TLS domain: {}", config.censorship.tls_domain);
if let Some(ref sock) = config.censorship.mask_unix_sock {
info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
if !std::path::Path::new(sock).exists() {
warn!("Unix socket '{}' does not exist yet. Masking will fail until it appears.", sock);
} }
} else { "--help" | "-h" => {
info!("Mask: {} -> {}:{}", eprintln!("Usage: telemt [config.toml] [OPTIONS]");
config.censorship.mask, eprintln!();
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), eprintln!("Options:");
config.censorship.mask_port); eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start");
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
} }
i += 1;
if config.censorship.tls_domain == "www.google.com" { }
warn!("Using default tls_domain. Consider setting a custom domain.");
} (config_path, silent, log_level)
}
let prefer_ipv6 = config.general.prefer_ipv6;
let use_middle_proxy = config.general.use_middle_proxy; #[tokio::main]
let config = Arc::new(config); async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let stats = Arc::new(Stats::new()); let (config_path, cli_silent, cli_log_level) = parse_cli();
let rng = Arc::new(SecureRandom::new());
let config = match ProxyConfig::load(&config_path) {
let replay_checker = Arc::new(ReplayChecker::new( Ok(c) => c,
config.access.replay_check_len, Err(e) => {
Duration::from_secs(config.access.replay_window_secs), if std::path::Path::new(&config_path).exists() {
)); eprintln!("[telemt] Error: {}", e);
std::process::exit(1);
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Connection concurrency limit
let _max_connections = Arc::new(Semaphore::new(10_000));
// =====================================================================
// Middle Proxy initialization (if enabled)
// =====================================================================
let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
info!("=== Middle Proxy Mode ===");
// ad_tag (proxy_tag) for advertising
let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
hex::decode(tag).unwrap_or_else(|_| {
warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
Vec::new()
})
});
// =============================================================
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
//
// C MTProxy uses TWO separate secrets:
// -S flag = 16-byte user secret for client obfuscation
// --aes-pwd = 32-512 byte binary file for ME RPC auth
//
// proxy-secret is from: https://core.telegram.org/getProxySecret
// =============================================================
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
Ok(proxy_secret) => {
info!(
secret_len = proxy_secret.len(),
key_sig = format_args!("0x{:08x}",
if proxy_secret.len() >= 4 {
u32::from_le_bytes([proxy_secret[0], proxy_secret[1],
proxy_secret[2], proxy_secret[3]])
} else { 0 }),
"Proxy-secret loaded"
);
let pool = MePool::new(proxy_tag, proxy_secret);
match pool.init(2, &rng).await {
Ok(()) => {
info!("Middle-End pool initialized successfully");
// Phase 4: Start health monitor
let pool_clone = pool.clone();
let rng_clone = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_health_monitor(
pool_clone, rng_clone, 2,
).await;
});
Some(pool)
}
Err(e) => {
error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
None
}
}
}
Err(e) => {
error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
None
}
}
} else { } else {
None let default = ProxyConfig::default();
}; std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
eprintln!("[telemt] Created default config at {}", config_path);
if me_pool.is_some() { default
info!("Transport: Middle Proxy (supports all DCs including CDN)");
} else {
info!("Transport: Direct TCP (standard DCs only)");
}
// Startup DC ping (only meaningful in direct mode)
if me_pool.is_none() {
info!("=== Telegram DC Connectivity ===");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
info!(" via {}", upstream_result.upstream_name);
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) {
(Some(rtt), _) => {
info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
}
(None, Some(err)) => {
info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
}
_ => {
info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
}
}
}
}
info!("================================");
}
// Background tasks
let um_clone = upstream_manager.clone();
tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default()
};
match create_listener(addr, &options) {
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip
} else if listener_conf.ip.is_unspecified() {
if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else {
detected_ip.ipv6.unwrap_or(listener_conf.ip)
}
} else {
listener_conf.ip
};
if !config.show_link.is_empty() {
info!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret);
}
if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex);
}
} else {
warn!("User '{}' in show_link not found", user_name);
}
}
info!("------------------------");
}
listeners.push(listener);
},
Err(e) => {
error!("Failed to bind to {}: {}", addr, e);
}
} }
} }
};
if listeners.is_empty() {
error!("No listeners. Exiting."); if let Err(e) = config.validate() {
std::process::exit(1); eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
let has_rust_log = std::env::var("RUST_LOG").is_ok();
let effective_log_level = if cli_silent {
LogLevel::Silent
} else if let Some(ref s) = cli_log_level {
LogLevel::from_str_loose(s)
} else {
config.general.log_level.clone()
};
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt::Layer::default())
.init();
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
info!("Modes: classic={} secure={} tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls);
info!("TLS domain: {}", config.censorship.tls_domain);
if let Some(ref sock) = config.censorship.mask_unix_sock {
info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
if !std::path::Path::new(sock).exists() {
warn!("Unix socket '{}' does not exist yet. Masking will fail until it appears.", sock);
} }
} else {
// Switch to user-configured log level after startup info!("Mask: {} -> {}:{}",
let runtime_filter = if has_rust_log { config.censorship.mask,
EnvFilter::from_default_env() config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
} else { config.censorship.mask_port);
EnvFilter::new(effective_log_level.to_filter_str()) }
};
filter_handle.reload(runtime_filter).expect("Failed to switch log filter"); if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain. Consider setting a custom domain.");
for listener in listeners { }
let config = config.clone();
let stats = stats.clone(); let prefer_ipv6 = config.general.prefer_ipv6;
let upstream_manager = upstream_manager.clone(); let use_middle_proxy = config.general.use_middle_proxy;
let replay_checker = replay_checker.clone(); let config = Arc::new(config);
let buffer_pool = buffer_pool.clone(); let stats = Arc::new(Stats::new());
let rng = rng.clone(); let rng = Arc::new(SecureRandom::new());
let me_pool = me_pool.clone();
let replay_checker = Arc::new(ReplayChecker::new(
tokio::spawn(async move { config.access.replay_check_len,
loop { Duration::from_secs(config.access.replay_window_secs),
match listener.accept().await { ));
Ok((stream, peer_addr)) => {
let config = config.clone(); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
let stats = stats.clone(); let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); // Connection concurrency limit
let buffer_pool = buffer_pool.clone(); let _max_connections = Arc::new(Semaphore::new(10_000));
let rng = rng.clone();
let me_pool = me_pool.clone(); // =====================================================================
// Middle Proxy initialization (if enabled)
// =====================================================================
let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
info!("=== Middle Proxy Mode ===");
// ad_tag (proxy_tag) for advertising
let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
hex::decode(tag).unwrap_or_else(|_| {
warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
Vec::new()
})
});
// =============================================================
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
//
// C MTProxy uses TWO separate secrets:
// -S flag = 16-byte user secret for client obfuscation
// --aes-pwd = 32-512 byte binary file for ME RPC auth
//
// proxy-secret is from: https://core.telegram.org/getProxySecret
// =============================================================
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
Ok(proxy_secret) => {
info!(
secret_len = proxy_secret.len(),
key_sig = format_args!("0x{:08x}",
if proxy_secret.len() >= 4 {
u32::from_le_bytes([proxy_secret[0], proxy_secret[1],
proxy_secret[2], proxy_secret[3]])
} else { 0 }),
"Proxy-secret loaded"
);
let pool = MePool::new(proxy_tag, proxy_secret);
match pool.init(2, &rng).await {
Ok(()) => {
info!("Middle-End pool initialized successfully");
// Phase 4: Start health monitor
let pool_clone = pool.clone();
let rng_clone = rng.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( crate::transport::middle_proxy::me_health_monitor(
stream, peer_addr, config, stats, pool_clone, rng_clone, 2,
upstream_manager, replay_checker, buffer_pool, rng, ).await;
me_pool,
).run().await {
debug!(peer = %peer_addr, error = %e, "Connection error");
}
}); });
Some(pool)
} }
Err(e) => { Err(e) => {
error!("Accept error: {}", e); error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
tokio::time::sleep(Duration::from_millis(100)).await; None
} }
} }
} }
}); Err(e) => {
} error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
None
match signal::ctrl_c().await { }
Ok(()) => info!("Shutting down..."), }
Err(e) => error!("Signal error: {}", e), } else {
} None
};
Ok(())
if me_pool.is_some() {
info!("Transport: Middle Proxy (supports all DCs including CDN)");
} else {
info!("Transport: Direct TCP (standard DCs only)");
} }
// Startup DC ping (only meaningful in direct mode)
if me_pool.is_none() {
info!("================= Telegram DC Connectivity =================");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
// Show which IP version is in use and which is fallback
if upstream_result.both_available {
if prefer_ipv6 {
info!(" IPv6 in use and IPv4 is fallback");
} else {
info!(" IPv4 in use and IPv6 is fallback");
}
} else {
let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some());
if v6_works && !v4_works {
info!(" IPv6 only (IPv4 unavailable)");
} else if v4_works && !v6_works {
info!(" IPv4 only (IPv6 unavailable)");
} else if !v6_works && !v4_works {
info!(" No connectivity!");
}
}
info!(" via {}", upstream_result.upstream_name);
info!("============================================================");
// Print IPv6 results first
for dc in &upstream_result.v6_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
// Align: IPv6 addresses are longer, use fewer tabs
// [2001:b28:f23d:f001::a]:443 = ~28 chars
info!(" DC{} [IPv6] {}:\t\t{:.0} ms", dc.dc_idx, addr_str, rtt);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(" DC{} [IPv6] {}:\t\tFAIL ({})", dc.dc_idx, addr_str, err);
}
}
}
info!("============================================================");
// Print IPv4 results
for dc in &upstream_result.v4_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
// Align: IPv4 addresses are shorter, use more tabs
// 149.154.175.50:443 = ~18 chars
info!(" DC{} [IPv4] {}:\t\t\t\t{:.0} ms", dc.dc_idx, addr_str, rtt);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(" DC{} [IPv4] {}:\t\t\t\tFAIL ({})", dc.dc_idx, addr_str, err);
}
}
}
info!("============================================================");
}
}
// Background tasks
let um_clone = upstream_manager.clone();
tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default()
};
match create_listener(addr, &options) {
Ok(socket) => {
let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr);
let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip
} else if listener_conf.ip.is_unspecified() {
if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else {
detected_ip.ipv6.unwrap_or(listener_conf.ip)
}
} else {
listener_conf.ip
};
if !config.show_link.is_empty() {
info!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name);
if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret);
}
if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret);
}
if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex);
}
} else {
warn!("User '{}' in show_link not found", user_name);
}
}
info!("------------------------");
}
listeners.push(listener);
},
Err(e) => {
error!("Failed to bind to {}: {}", addr, e);
}
}
}
if listeners.is_empty() {
error!("No listeners. Exiting.");
std::process::exit(1);
}
// Switch to user-configured log level after startup
let runtime_filter = if has_rust_log {
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
filter_handle.reload(runtime_filter).expect("Failed to switch log filter");
for listener in listeners {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, peer_addr)) => {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let me_pool = me_pool.clone();
tokio::spawn(async move {
if let Err(e) = ClientHandler::new(
stream, peer_addr, config, stats,
upstream_manager, replay_checker, buffer_pool, rng,
me_pool,
).run().await {
debug!(peer = %peer_addr, error = %e, "Connection error");
}
});
}
Err(e) => {
error!("Accept error: {}", e);
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
}
});
}
match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e),
}
Ok(())
}
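What the commit name points at: the probing code in upstream.rs below now records, per DC, which address family actually works, and connection setup can then prefer the configured family but silently fall back to the other one when it is the only one that answered. The helper here is only a hedged sketch; choose_dc_addr does not exist in this diff, while TG_DATACENTERS_V4/V6, TG_DATACENTER_PORT and IpPreference are the items defined in upstream.rs.

// Hedged sketch of the per-DC "IP version superfallback" selection; choose_dc_addr is hypothetical.
fn choose_dc_addr(dc_slot: usize, prefer_ipv6: bool, pref: IpPreference) -> SocketAddr {
    let v6 = SocketAddr::new(TG_DATACENTERS_V6[dc_slot], TG_DATACENTER_PORT);
    let v4 = SocketAddr::new(TG_DATACENTERS_V4[dc_slot], TG_DATACENTER_PORT);
    match pref {
        IpPreference::PreferV6 => v6, // probing favoured IPv6 for this DC
        IpPreference::PreferV4 => v4, // probing showed only IPv4 works for this DC
        IpPreference::BothWork
        | IpPreference::Unknown
        | IpPreference::Unavailable => {
            // Both work, nothing known yet, or both failed: honour the configured preference.
            if prefer_ipv6 { v6 } else { v4 }
        }
    }
}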

View File

@@ -19,9 +19,9 @@
use crate::crypto::{AesCtr, SecureRandom};
use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake,
- HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
+ HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce_with_ciphers,
};
use crate::proxy::relay::relay_bidirectional;
use crate::proxy::masking::handle_bad_client;
@@ -340,8 +340,8 @@
/// - ME returns responses in RPC_PROXY_ANS envelope
async fn handle_via_middle_proxy<R, W>(
- mut client_reader: CryptoReader<R>,
- mut client_writer: CryptoWriter<W>,
+ crypto_reader: CryptoReader<R>,
+ crypto_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
@@ -352,7 +352,10 @@
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
+ let mut client_reader = crypto_reader;
+ let mut client_writer = crypto_writer;
let user = success.user.clone();
let peer = success.peer;
info!(
@@ -635,8 +638,8 @@
rng,
config.general.fast_mode,
);
- let encrypted_nonce = encrypt_tg_nonce(&nonce);
+ let (encrypted_nonce, mut tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce);
debug!(
peer = %success.peer,
@@ -649,12 +652,9 @@
let (read_half, write_half) = stream.into_split();
- let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
- let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
Ok((
- CryptoReader::new(read_half, decryptor),
- CryptoWriter::new(write_half, encryptor),
+ CryptoReader::new(read_half, tg_decryptor),
+ CryptoWriter::new(write_half, tg_encryptor),
))
}
}

View File

@@ -61,26 +61,26 @@ where
W: AsyncWrite + Unpin,
{
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer };
}
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
})
.collect();
let validation = match tls::validate_tls_handshake(
handshake,
&secrets,
@@ -96,12 +96,12 @@ where
return HandshakeResult::BadClient { reader, writer };
}
};
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer },
};
let response = tls::build_server_hello(
secret,
&validation.digest,
@@ -109,27 +109,27 @@ where
config.censorship.fake_cert_len,
rng,
);
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await {
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e));
}
if let Err(e) = writer.flush().await {
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e));
}
replay_checker.add_tls_digest(digest_half);
info!(
peer = %peer,
user = %validation.user,
"TLS handshake successful"
);
HandshakeResult::Success((
FakeTlsReader::new(reader),
FakeTlsWriter::new(writer),
@@ -152,75 +152,74 @@ where
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer };
}
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) {
Ok(s) => s,
Err(_) => continue,
};
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake);
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into()
.unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag,
None => continue,
};
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
};
if !mode_ok {
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
continue;
}
let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
);
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
replay_checker.add_handshake(dec_prekey_iv);
- let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv);
let success = HandshakeSuccess {
user: user.clone(),
dc_idx,
@@ -232,7 +231,7 @@ where
peer,
is_tls,
};
info!(
peer = %peer,
user = %user,
@@ -241,14 +240,14 @@ where
tls = is_tls,
"MTProto handshake successful"
);
return HandshakeResult::Success((
CryptoReader::new(reader, decryptor),
CryptoWriter::new(writer, encryptor),
success,
));
}
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer }
}
@@ -265,88 +264,101 @@ pub fn generate_tg_nonce(
loop {
let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode {
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes());
}
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
}
}
- /// Encrypt nonce for sending to Telegram
- pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
+ /// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
+ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
- let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
- let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
- let mut encryptor = AesCtr::new(&key, iv);
- let encrypted_full = encryptor.encrypt(nonce);
+ let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
+ let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
+ let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
+ let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
+ let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
+ let mut encryptor = AesCtr::new(&enc_key, enc_iv);
+ let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
- result
+ let decryptor = AesCtr::new(&dec_key, dec_iv);
+ (result, encryptor, decryptor)
+ }
+ /// Encrypt nonce for sending to Telegram (legacy function for compatibility)
+ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
+ let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
+ encrypted
}
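The point of handing the cipher objects back: AES-CTR is one running keystream, so once the nonce has been encrypted the counter has already advanced past it, and that same encryptor has to be reused for the payload that follows (the old code rebuilt fresh ciphers at counter 0). A small hedged sketch of that property, assuming AesCtr is the streaming cipher exactly as used above; the key, IV and payload values are invented:

#[test]
fn reusing_the_nonce_cipher_keeps_the_counter_position() {
    // Invented key material; only the counter behaviour matters here.
    let key = [7u8; 32];
    let iv = 42u128;
    let nonce = [0xABu8; HANDSHAKE_LEN];
    let payload: &[u8] = b"first payload frame";

    // Cipher that already ran over the nonce: its counter sits past the header when the payload starts.
    let mut continued = AesCtr::new(&key, iv);
    let _ = continued.encrypt(&nonce);
    let with_state = continued.encrypt(payload);

    // A freshly created cipher starts again at block 0 and produces different bytes,
    // which is the stream corruption the _with_ciphers variant avoids.
    let mut fresh = AesCtr::new(&key, iv);
    let without_state = fresh.encrypt(payload);

    assert_ne!(with_state, without_state);
}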
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
assert_eq!(nonce.len(), HANDSHAKE_LEN);
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
}
#[test]
fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
}
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
@@ -360,10 +372,10 @@ mod tests {
peer: "127.0.0.1:1234".parse().unwrap(), peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true, is_tls: true,
}; };
assert_eq!(success.dec_key, [0xAA; 32]); assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]); assert_eq!(success.enc_key, [0xCC; 32]);
drop(success); drop(success);
// Drop impl zeroizes key material without panic // Drop impl zeroizes key material without panic
} }

View File

@@ -1,4 +1,6 @@
//! Upstream Management with per-DC latency-weighted selection
//!
//! IPv6/IPv4 connectivity checks with configurable preference.
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
@@ -18,6 +20,9 @@ use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters
const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5;
// ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)]
@@ -30,19 +35,42 @@ impl LatencyEma {
const fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
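A short usage sketch of the EMA above (sample values invented; with alpha = 0.3 each new sample pulls the estimate 30% of the way toward itself). Shown here only for illustration; as a real test it would live in the module's test code:

#[test]
fn ema_moves_toward_new_samples() {
    // Invented samples tracing the update rule above.
    let mut ema = LatencyEma::new(0.3);
    assert_eq!(ema.get(), None);
    ema.update(100.0); // first sample is taken as-is       -> 100.0
    ema.update(200.0); // 100.0 * 0.7 + 200.0 * 0.3         -> 130.0
    ema.update(50.0);  // 130.0 * 0.7 +  50.0 * 0.3         -> 106.0
    assert!((ema.get().unwrap() - 106.0).abs() < 1e-9);
}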
// ============= Per-DC IP Preference Tracking =============
/// Tracks which IP version works for each DC
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpPreference {
/// Not yet tested
Unknown,
/// IPv6 works
PreferV6,
/// Only IPv4 works (IPv6 failed)
PreferV4,
/// Both work
BothWork,
/// Both failed
Unavailable,
}
impl Default for IpPreference {
fn default() -> Self {
Self::Unknown
}
}
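A small sketch of how per-DC probe results might map onto this enum. The real mapping lives at the end of ping_all_dcs and is cut off in this excerpt, so these arms are an assumption that only follows the variant documentation above:

// Hedged sketch: classify is hypothetical; it mirrors the variant docs, not confirmed code.
fn classify(v6_ok: bool, v4_ok: bool) -> IpPreference {
    match (v6_ok, v4_ok) {
        (true, true) => IpPreference::BothWork,
        (true, false) => IpPreference::PreferV6,
        (false, true) => IpPreference::PreferV4,
        (false, false) => IpPreference::Unavailable,
    }
}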
// ============= Upstream State =============
#[derive(Debug)]
@@ -53,6 +81,8 @@ struct UpstreamState {
last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS],
/// Per-DC IP version preference (learned from connectivity tests)
dc_ip_pref: [IpPreference; NUM_DCS],
}
impl UpstreamState {
@@ -63,16 +93,11 @@ impl UpstreamState {
fails: 0,
last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
dc_ip_pref: [IpPreference::Unknown; NUM_DCS],
}
}
/// Map DC index to latency array slot (0..NUM_DCS).
///
/// Matches the C implementation's `mf_cluster_lookup` behavior:
/// - Standard DCs ±1..±5 → direct mapping to array index 0..4
/// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2)
/// This matches Telegram's `default 2;` in proxy-multi.conf.
/// - There is NO modular arithmetic in the C implementation.
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc == 0 {
@@ -82,25 +107,22 @@ impl UpstreamState {
Some(abs_dc - 1)
} else {
// Unknown DC → default cluster (DC 2, index 1)
// Same as C: mf_cluster_lookup returns default_cluster
Some(1)
}
}
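A few mappings implied by the rule above (the abs_dc == 0 branch is elided by this hunk, so it is not shown):

// Illustration only; as assertions these would live in the test module, not inside this impl.
//   UpstreamState::dc_array_idx(2)  == Some(1)   // standard DC 2 -> slot 1
//   UpstreamState::dc_array_idx(-4) == Some(3)   // sign dropped via unsigned_abs()
//   UpstreamState::dc_array_idx(7)  == Some(1)   // unknown DC    -> default slot (DC 2)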
/// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
// Try DC-specific latency first
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() {
return Some(ms);
}
}
// Fallback: average of all known DC latencies
let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
if count > 0 { Some(sum / count as f64) } else { None }
}
}
@@ -114,11 +136,14 @@ pub struct DcPingResult {
pub error: Option<String>,
}
- /// Result of startup ping for one upstream
+ /// Result of startup ping for one upstream (separate v6/v4 results)
#[derive(Debug, Clone)]
pub struct StartupPingResult {
- pub results: Vec<DcPingResult>,
+ pub v6_results: Vec<DcPingResult>,
+ pub v4_results: Vec<DcPingResult>,
pub upstream_name: String,
+ /// True if both IPv6 and IPv4 have at least one working DC
+ pub both_available: bool,
}
// ============= Upstream Manager =============
@@ -134,22 +159,13 @@ impl UpstreamManager {
.filter(|c| c.enabled)
.map(UpstreamState::new)
.collect();
Self {
upstreams: Arc::new(RwLock::new(states)),
}
}
/// Select upstream using latency-weighted random selection.
///
/// `effective_weight = config_weight × latency_factor`
///
/// where `latency_factor = 1000 / latency_ms` if latency is known,
/// or `1.0` if no latency data is available.
///
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
/// factor 5 — the faster route is 4× more likely to be chosen
/// (all else being equal).
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await;
if upstreams.is_empty() {
@@ -161,34 +177,32 @@ impl UpstreamManager {
.filter(|(_, u)| u.healthy)
.map(|(i, _)| i)
.collect();
if healthy.is_empty() {
// All unhealthy — pick any
return Some(rand::rng().gen_range(0..upstreams.len()));
}
if healthy.len() == 1 {
return Some(healthy[0]);
}
// Calculate latency-weighted scores
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx)
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
.unwrap_or(1.0);
(i, base * latency_factor)
}).collect();
let total: f64 = weights.iter().map(|(_, w)| w).sum();
if total <= 0.0 {
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
}
let mut choice: f64 = rand::rng().gen_range(0.0..total);
for &(idx, weight) in &weights {
if choice < weight {
trace!(
@@ -202,25 +216,22 @@ impl UpstreamManager {
}
choice -= weight;
}
Some(healthy[0])
}
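A standalone restatement of the weighting used above, mirroring the 50 ms vs 200 ms example from the removed doc comment; it is not part of the module and the inputs are invented:

// Mirrors the latency_factor computation in select_upstream.
fn effective_weight(config_weight: f64, latency_ms: f64) -> f64 {
    let latency_factor = if latency_ms > 1.0 { 1000.0 / latency_ms } else { 1000.0 };
    config_weight * latency_factor
}
// effective_weight(1.0, 50.0)  == 20.0
// effective_weight(1.0, 200.0) ==  5.0   -> the 50 ms route wins the weighted draw 4x as often.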
/// Connect to target through a selected upstream.
///
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
/// Pass `None` if DC index is unknown.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
let upstream = {
let guard = self.upstreams.read().await;
guard[idx].config.clone()
};
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
@@ -231,8 +242,7 @@ impl UpstreamManager {
}
u.healthy = true;
u.fails = 0;
// Store per-DC latency
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms);
}
@@ -253,92 +263,93 @@ impl UpstreamManager {
}
}
}
async fn connect_via_upstream(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<TcpStream> {
match &config.upstream_type {
UpstreamType::Direct { interface } => {
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(target, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&target.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
Ok(stream)
},
UpstreamType::Socks4 { address, interface, user_id } => {
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks4(&mut stream, target, user_id.as_deref()).await?;
Ok(stream)
},
UpstreamType::Socks5 { address, interface, username, password } => {
let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
let bind_ip = interface.as_ref()
.and_then(|s| s.parse::<IpAddr>().ok());
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) {
Ok(()) => {},
Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)),
}
let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}
connect_socks5(&mut stream, target, username.as_deref(), password.as_deref()).await?;
Ok(stream)
},
}
}
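All three branches above repeat the same non-blocking connect sequence; here it is condensed into one hedged sketch that reuses this module's create_outgoing_socket_bound helper and Result alias as-is. It is not part of the diff, only an annotated restatement of the pattern:

// Hedged sketch of the shared connect pattern; error handling mirrors the branches above.
async fn connect_nonblocking_sketch(target: SocketAddr) -> Result<TcpStream> {
    let socket = create_outgoing_socket_bound(target, None)?;
    socket.set_nonblocking(true)?;
    match socket.connect(&target.into()) {
        Ok(()) => {}
        // A non-blocking connect normally reports EINPROGRESS; completion is checked below.
        Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS)
            || err.kind() == std::io::ErrorKind::WouldBlock => {}
        Err(err) => return Err(ProxyError::Io(err)),
    }
    let std_stream: std::net::TcpStream = socket.into();
    let stream = TcpStream::from_std(std_stream)?;
    stream.writable().await?;               // resolves once the connect finishes (or fails)
    if let Some(e) = stream.take_error()? {  // SO_ERROR reveals whether the connect actually failed
        return Err(ProxyError::Io(e));
    }
    Ok(stream)
}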
- // ============= Startup Ping =============
+ // ============= Startup Ping (test both IPv6 and IPv4) =============
/// Ping all Telegram DCs through all upstreams.
/// Tests BOTH IPv6 and IPv4, returns separate results for each.
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
@@ -346,11 +357,9 @@ impl UpstreamManager {
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
@@ -359,130 +368,260 @@ impl UpstreamManager {
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address), UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address), UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
}; };
let mut dc_results = Vec::new(); let mut v6_results = Vec::new();
let mut v4_results = Vec::new();
// === Ping IPv6 first ===
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() { for dc_zero_idx in 0..NUM_DCS {
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT); let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(5), Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(upstream_config, dc_addr) self.ping_single_dc(&upstream_config, addr_v6)
).await; ).await;
let result = match ping_result { let ping_result = match result {
Ok(Ok(rtt_ms)) => { Ok(Ok(rtt_ms)) => {
// Store per-DC latency
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) { if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms); u.dc_latency[dc_zero_idx].update(rtt_ms);
} }
DcPingResult { DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: Some(rtt_ms), rtt_ms: Some(rtt_ms),
error: None, error: None,
} }
} }
Ok(Err(e)) => DcPingResult { Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: None, rtt_ms: None,
error: Some(e.to_string()), error: Some(e.to_string()),
}, },
Err(_) => DcPingResult { Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: None, rtt_ms: None,
error: Some("timeout (5s)".to_string()), error: Some("timeout".to_string()),
}, },
}; };
dc_results.push(result); v6_results.push(ping_result);
} }
// === Then ping IPv4 ===
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v4)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v4_results.push(ping_result);
}
// Check if both IP versions have at least one working DC
let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some());
let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some());
let both_available = v6_has_working && v4_has_working;
// Update IP preference for each DC
{
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
for dc_zero_idx in 0..NUM_DCS {
let v6_ok = v6_results[dc_zero_idx].rtt_ms.is_some();
let v4_ok = v4_results[dc_zero_idx].rtt_ms.is_some();
u.dc_ip_pref[dc_zero_idx] = match (v6_ok, v4_ok) {
(true, true) => IpPreference::BothWork,
(true, false) => IpPreference::PreferV6,
(false, true) => IpPreference::PreferV4,
(false, false) => IpPreference::Unavailable,
};
}
}
}
all_results.push(StartupPingResult { all_results.push(StartupPingResult {
results: dc_results, v6_results,
v4_results,
upstream_name, upstream_name,
both_available,
}); });
} }
all_results all_results
} }
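// --- Illustrative sketch (not part of this commit) ---
// The per-DC preference written above boils down to a pure mapping from the two ping
// outcomes to an IpPreference. Shown here as a small helper so the decision table is
// visible in one place; the function name is hypothetical, the variants are exactly
// the ones matched in ping_all_dcs.
fn classify_ip_preference_sketch(v6_ok: bool, v4_ok: bool) -> IpPreference {
    match (v6_ok, v4_ok) {
        (true, true) => IpPreference::BothWork,      // both families answered: honor config order
        (true, false) => IpPreference::PreferV6,     // only IPv6 answered the startup ping
        (false, true) => IpPreference::PreferV4,     // only IPv4 answered the startup ping
        (false, false) => IpPreference::Unavailable, // neither family reachable via this upstream
    }
}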
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> { async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now(); let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?; let _stream = self.connect_via_upstream(config, target).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0) Ok(start.elapsed().as_secs_f64() * 1000.0)
} }
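// --- Illustrative sketch (not part of this commit) ---
// One way a caller could summarise what ping_all_dcs returns: count reachable DCs per
// IP family for every upstream. Field names come from the structs used above; the log
// format and the helper name are assumptions.
fn log_ping_summary_sketch(results: &[StartupPingResult]) {
    for r in results {
        let v6_ok = r.v6_results.iter().filter(|d| d.rtt_ms.is_some()).count();
        let v4_ok = r.v4_results.iter().filter(|d| d.rtt_ms.is_some()).count();
        info!(
            upstream = %r.upstream_name,
            v6_reachable = v6_ok,
            v4_reachable = v4_ok,
            both_available = r.both_available,
            "Startup ping summary"
        );
    }
}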
// ============= Health Checks ============= // ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval. /// Background health check: rotates through DCs, 30s interval.
/// Uses preferred IP version based on config.
pub async fn run_health_checks(&self, prefer_ipv6: bool) { pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut dc_rotation = 0usize; let mut dc_rotation = 0usize;
loop { loop {
tokio::time::sleep(Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % datacenters.len(); let dc_zero_idx = dc_rotation % NUM_DCS;
dc_rotation += 1; dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT); let dc_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
} else {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
};
let fallback_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
} else {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
};
let count = self.upstreams.read().await.len(); let count = self.upstreams.read().await.len();
for i in 0..count { for i in 0..count {
let config = { let config = {
let guard = self.upstreams.read().await; let guard = self.upstreams.read().await;
guard[i].config.clone() guard[i].config.clone()
}; };
let start = Instant::now(); let start = Instant::now();
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(10), Duration::from_secs(10),
self.connect_via_upstream(&config, check_target) self.connect_via_upstream(&config, dc_addr)
).await; ).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result {
    Ok(Ok(_stream)) => {
        let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
        u.dc_latency[dc_zero_idx].update(rtt_ms);
        if !u.healthy {
            info!(
                rtt = format!("{:.0}ms", rtt_ms),
                dc = dc_zero_idx + 1,
                "Upstream recovered"
            );
        }
        u.healthy = true;
        u.fails = 0;
    }
    Ok(Err(e)) => {
        u.fails += 1;
        debug!(dc = dc_zero_idx + 1, fails = u.fails,
            "Health check failed: {}", e);
        if u.fails > 3 {
            u.healthy = false;
            warn!("Upstream unhealthy (fails)");
        }
    }
    Err(_) => {
        u.fails += 1;
        debug!(dc = dc_zero_idx + 1, fails = u.fails,
            "Health check timeout");
        if u.fails > 3 {
            u.healthy = false;
            warn!("Upstream unhealthy (timeout)");
        }
    }
}
u.last_check = std::time::Instant::now();

match result {
    Ok(Ok(_stream)) => {
        let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
        let mut guard = self.upstreams.write().await;
        let u = &mut guard[i];
        u.dc_latency[dc_zero_idx].update(rtt_ms);
        if !u.healthy {
            info!(
                rtt = format!("{:.0} ms", rtt_ms),
                dc = dc_zero_idx + 1,
                "Upstream recovered"
            );
        }
        u.healthy = true;
        u.fails = 0;
        u.last_check = std::time::Instant::now();
    }
    Ok(Err(_)) | Err(_) => {
        // Try fallback
        debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
        let start2 = Instant::now();
        let result2 = tokio::time::timeout(
            Duration::from_secs(10),
            self.connect_via_upstream(&config, fallback_addr)
        ).await;
        let mut guard = self.upstreams.write().await;
        let u = &mut guard[i];
        match result2 {
            Ok(Ok(_stream)) => {
                let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
                u.dc_latency[dc_zero_idx].update(rtt_ms);
                if !u.healthy {
                    info!(
                        rtt = format!("{:.0} ms", rtt_ms),
                        dc = dc_zero_idx + 1,
                        "Upstream recovered (fallback)"
                    );
                }
                u.healthy = true;
                u.fails = 0;
            }
            Ok(Err(e)) => {
                u.fails += 1;
                debug!(dc = dc_zero_idx + 1, fails = u.fails,
                    "Health check failed (both): {}", e);
                if u.fails > 3 {
                    u.healthy = false;
                    warn!("Upstream unhealthy (fails)");
                }
            }
            Err(_) => {
                u.fails += 1;
                debug!(dc = dc_zero_idx + 1, fails = u.fails,
                    "Health check timeout (both)");
                if u.fails > 3 {
                    u.healthy = false;
                    warn!("Upstream unhealthy (timeout)");
                }
            }
        }
        u.last_check = std::time::Instant::now();
    }
}
} }
} }
} }
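// --- Illustrative sketch (not part of this commit) ---
// The health check above follows a "preferred family first, then the other family"
// probe. Factored into one helper it reads like this; the method name is hypothetical,
// the 10-second per-attempt timeout matches the code above, and the return value is
// the RTT in milliseconds of the first address that connected.
async fn probe_with_fallback_sketch(
    &self,
    config: &UpstreamConfig,
    preferred: SocketAddr,
    fallback: SocketAddr,
) -> Option<f64> {
    for addr in [preferred, fallback] {
        let start = Instant::now();
        let attempt = tokio::time::timeout(
            Duration::from_secs(10),
            self.connect_via_upstream(config, addr),
        ).await;
        if matches!(attempt, Ok(Ok(_))) {
            return Some(start.elapsed().as_secs_f64() * 1000.0);
        }
    }
    None
}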
/// Get the preferred IP version for a DC, as measured on the first upstream (for use by other components)
pub async fn get_dc_ip_preference(&self, dc_idx: i16) -> Option<IpPreference> {
let guard = self.upstreams.read().await;
if guard.is_empty() {
return None;
}
UpstreamState::dc_array_idx(dc_idx)
.map(|idx| guard[0].dc_ip_pref[idx])
}
/// Get the DC address for the given index, using the configured IP version preference
pub async fn get_dc_addr(&self, dc_idx: i16, prefer_ipv6: bool) -> Option<SocketAddr> {
let arr_idx = UpstreamState::dc_array_idx(dc_idx)?;
let ip = if prefer_ipv6 {
TG_DATACENTERS_V6[arr_idx]
} else {
TG_DATACENTERS_V4[arr_idx]
};
Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
}
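// --- Illustrative sketch (not part of this commit) ---
// How the two accessors above could be combined by a caller: use the measured per-DC
// preference when only one family works, and fall back to the configured prefer_ipv6
// flag otherwise. The method name and the exact policy are assumptions, not code from
// this commit.
async fn pick_dc_addr_sketch(&self, dc_idx: i16, prefer_ipv6: bool) -> Option<SocketAddr> {
    match self.get_dc_ip_preference(dc_idx).await {
        Some(IpPreference::PreferV6) => self.get_dc_addr(dc_idx, true).await,
        Some(IpPreference::PreferV4) => self.get_dc_addr(dc_idx, false).await,
        // BothWork, Unavailable, or no upstreams yet: follow the config preference.
        _ => self.get_dc_addr(dc_idx, prefer_ipv6).await,
    }
}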
} }