Merge pull request #21 from telemt/1.2.0.0

1.2.0.0
This commit is contained in:
Alexey
2026-02-11 17:00:46 +03:00
committed by GitHub
22 changed files with 1322 additions and 581 deletions

View File

@@ -14,6 +14,11 @@ jobs:
name: Build name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4

View File

@@ -1,15 +1,14 @@
[package] [package]
name = "telemt" name = "telemt"
version = "1.0.0" version = "1.2.0"
edition = "2021" edition = "2024"
rust-version = "1.75"
[dependencies] [dependencies]
# C # C
libc = "0.2" libc = "0.2"
# Async runtime # Async runtime
tokio = { version = "1.35", features = ["full", "tracing"] } tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
# Crypto # Crypto
@@ -20,41 +19,41 @@ sha2 = "0.10"
sha1 = "0.10" sha1 = "0.10"
md-5 = "0.10" md-5 = "0.10"
hmac = "0.12" hmac = "0.12"
crc32fast = "1.3" crc32fast = "1.4"
zeroize = { version = "1.8", features = ["derive"] }
# Network # Network
socket2 = { version = "0.5", features = ["all"] } socket2 = { version = "0.5", features = ["all"] }
rustls = "0.22"
# Serial # Serialization
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.8" toml = "0.8"
# Utils # Utils
bytes = "1.5" bytes = "1.9"
thiserror = "1.0" thiserror = "2.0"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
parking_lot = "0.12" parking_lot = "0.12"
dashmap = "5.5" dashmap = "5.5"
lru = "0.12" lru = "0.12"
rand = "0.8" rand = "0.9"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
hex = "0.4" hex = "0.4"
base64 = "0.21" base64 = "0.22"
url = "2.5" url = "2.5"
regex = "1.10" regex = "1.11"
once_cell = "1.19"
crossbeam-queue = "0.3" crossbeam-queue = "0.3"
# HTTP # HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
[dev-dependencies] [dev-dependencies]
tokio-test = "0.4" tokio-test = "0.4"
criterion = "0.5" criterion = "0.5"
proptest = "1.4" proptest = "1.4"
futures = "0.3"
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"

View File

@@ -6,8 +6,13 @@ show_link = ["hello"]
[general] [general]
prefer_ipv6 = false prefer_ipv6 = false
fast_mode = true fast_mode = true
use_middle_proxy = false use_middle_proxy = true
# ad_tag = "..." ad_tag = "00000000000000000000000000000000"
# Log level: debug | verbose | normal | silent
# Can be overridden with --silent or --log-level CLI flags
# RUST_LOG env var takes absolute priority over all of these
log_level = "normal"
[general.modes] [general.modes]
classic = false classic = false
@@ -39,16 +44,16 @@ client_ack = 300
# === Anti-Censorship & Masking === # === Anti-Censorship & Masking ===
[censorship] [censorship]
tls_domain = "petrovich.ru" tls_domain = "google.ru"
mask = true mask = true
mask_port = 443 mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set # mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048 fake_cert_len = 2048
# === Access Control & Users === # === Access Control & Users ===
# username "hello" is used for example
[access] [access]
replay_check_len = 65536 replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false ignore_time_skew = false
[access.users] [access.users]
@@ -62,15 +67,11 @@ hello = "00000000000000000000000000000000"
# hello = 1073741824 # 1 GB # hello = 1073741824 # 1 GB
# === Upstreams & Routing === # === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]] [[upstreams]]
type = "direct" type = "direct"
enabled = true enabled = true
weight = 10 weight = 10
# SOCKS5
# [[upstreams]] # [[upstreams]]
# type = "socks5" # type = "socks5"
# address = "127.0.0.1:9050" # address = "127.0.0.1:9050"

300
src/cli.rs Normal file
View File

@@ -0,0 +1,300 @@
//! CLI commands: --init (fire-and-forget setup)
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rand::Rng;
/// Options for the init command
pub struct InitOptions {
pub port: u16,
pub domain: String,
pub secret: Option<String>,
pub username: String,
pub config_dir: PathBuf,
pub no_start: bool,
}
impl Default for InitOptions {
fn default() -> Self {
Self {
port: 443,
domain: "www.google.com".to_string(),
secret: None,
username: "user".to_string(),
config_dir: PathBuf::from("/etc/telemt"),
no_start: false,
}
}
}
/// Parse --init subcommand options from CLI args.
///
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
if !args.iter().any(|a| a == "--init") {
return None;
}
let mut opts = InitOptions::default();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--port" => {
i += 1;
if i < args.len() {
opts.port = args[i].parse().unwrap_or(443);
}
}
"--domain" => {
i += 1;
if i < args.len() {
opts.domain = args[i].clone();
}
}
"--secret" => {
i += 1;
if i < args.len() {
opts.secret = Some(args[i].clone());
}
}
"--user" => {
i += 1;
if i < args.len() {
opts.username = args[i].clone();
}
}
"--config-dir" => {
i += 1;
if i < args.len() {
opts.config_dir = PathBuf::from(&args[i]);
}
}
"--no-start" => {
opts.no_start = true;
}
_ => {}
}
i += 1;
}
Some(opts)
}
/// Run the fire-and-forget setup.
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[telemt] Fire-and-forget setup");
eprintln!();
// 1. Generate or validate secret
let secret = match opts.secret {
Some(s) => {
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
eprintln!("[error] Secret must be exactly 32 hex characters");
std::process::exit(1);
}
s
}
None => generate_secret(),
};
eprintln!("[+] Secret: {}", secret);
eprintln!("[+] User: {}", opts.username);
eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory
fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml");
// 3. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit
let exe_path = std::env::current_exe()
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path);
match fs::write(unit_path, &unit_content) {
Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display());
}
Err(e) => {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content);
// Still print links and config
print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(());
}
}
// 5. Reload systemd
run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start)
if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started");
// Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl")
.args(["is-active", "telemt.service"])
.output();
match status {
Ok(out) if out.status.success() => {
eprintln!("[+] Service is running");
}
_ => {
eprintln!("[!] Service may not have started correctly");
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
}
}
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service");
}
eprintln!();
// 8. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(())
}
fn generate_secret() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
hex::encode(bytes)
}
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!(
r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate
show_link = ["{username}"]
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"
[general.modes]
classic = false
secure = false
tls = true
[server]
port = {port}
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
[[server.listeners]]
ip = "0.0.0.0"
[[server.listeners]]
ip = "::"
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
[censorship]
tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
{username} = "{secret}"
[[upstreams]]
type = "direct"
enabled = true
weight = 10
"#,
username = username,
secret = secret,
port = port,
domain = domain,
)
}
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!(
r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/nicepkg/telemt
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true
[Install]
WantedBy=multi-user.target
"#,
exe = exe_path.display(),
config = config_path.display(),
)
}
fn run_cmd(cmd: &str, args: &[&str]) {
match Command::new(cmd).args(args).output() {
Ok(output) => {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
}
}
Err(e) => {
eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
}
}
}
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
let domain_hex = hex::encode(domain);
println!("=== Proxy Links ===");
println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex);
println!();
println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup.");
println!("Check: journalctl -u telemt.service | head -30");
println!("===================");
}

View File

@@ -14,6 +14,7 @@ fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() } fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 } fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 } fn default_replay_check_len() -> usize { 65536 }
fn default_replay_window_secs() -> u64 { 1800 }
fn default_handshake_timeout() -> u64 { 15 } fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 } fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 } fn default_keepalive() -> u64 { 60 }
@@ -28,6 +29,58 @@ fn default_metrics_whitelist() -> Vec<IpAddr> {
] ]
} }
// ============= Log Level =============
/// Logging verbosity level
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// All messages including trace (trace + debug + info + warn + error)
Debug,
/// Detailed operational logs (debug + info + warn + error)
Verbose,
/// Standard operational logs (info + warn + error)
#[default]
Normal,
/// Minimal output: only warnings and errors (warn + error).
/// Proxy links are still printed to stdout via println!.
Silent,
}
impl LogLevel {
/// Convert to tracing EnvFilter directive string
pub fn to_filter_str(&self) -> &'static str {
match self {
LogLevel::Debug => "trace",
LogLevel::Verbose => "debug",
LogLevel::Normal => "info",
LogLevel::Silent => "warn",
}
}
/// Parse from a loose string (CLI argument)
pub fn from_str_loose(s: &str) -> Self {
match s.to_lowercase().as_str() {
"debug" | "trace" => LogLevel::Debug,
"verbose" => LogLevel::Verbose,
"normal" | "info" => LogLevel::Normal,
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
_ => LogLevel::Normal,
}
}
}
impl std::fmt::Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogLevel::Debug => write!(f, "debug"),
LogLevel::Verbose => write!(f, "verbose"),
LogLevel::Normal => write!(f, "normal"),
LogLevel::Silent => write!(f, "silent"),
}
}
}
// ============= Sub-Configs ============= // ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -62,6 +115,9 @@ pub struct GeneralConfig {
#[serde(default)] #[serde(default)]
pub ad_tag: Option<String>, pub ad_tag: Option<String>,
#[serde(default)]
pub log_level: LogLevel,
} }
impl Default for GeneralConfig { impl Default for GeneralConfig {
@@ -72,6 +128,7 @@ impl Default for GeneralConfig {
fast_mode: true, fast_mode: true,
use_middle_proxy: false, use_middle_proxy: false,
ad_tag: None, ad_tag: None,
log_level: LogLevel::Normal,
} }
} }
} }
@@ -187,6 +244,9 @@ pub struct AccessConfig {
#[serde(default = "default_replay_check_len")] #[serde(default = "default_replay_check_len")]
pub replay_check_len: usize, pub replay_check_len: usize,
#[serde(default = "default_replay_window_secs")]
pub replay_window_secs: u64,
#[serde(default)] #[serde(default)]
pub ignore_time_skew: bool, pub ignore_time_skew: bool,
} }
@@ -201,6 +261,7 @@ impl Default for AccessConfig {
user_expirations: HashMap::new(), user_expirations: HashMap::new(),
user_data_quota: HashMap::new(), user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(), replay_check_len: default_replay_check_len(),
replay_window_secs: default_replay_window_secs(),
ignore_time_skew: false, ignore_time_skew: false,
} }
} }
@@ -299,20 +360,14 @@ impl ProxyConfig {
return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
} }
// Warn if using default tls_domain
if config.censorship.tls_domain == "www.google.com" {
tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml");
}
// Default mask_host to tls_domain if not set // Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() { if config.censorship.mask_host.is_none() {
tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain);
config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
// Random fake_cert_len // Random fake_cert_len
use rand::Rng; use rand::Rng;
config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
// Migration: Populate listeners if empty // Migration: Populate listeners if empty
if config.server.listeners.is_empty() { if config.server.listeners.is_empty() {
@@ -353,7 +408,6 @@ impl ProxyConfig {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
// Validate tls_domain format (basic check)
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config( return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)

View File

@@ -1,9 +1,19 @@
//! AES encryption implementations //! AES encryption implementations
//! //!
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption. //! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
//!
//! ## Zeroize policy
//!
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
//! The expanded key schedule lives inside that type and cannot be
//! zeroized from outside. Callers that hold raw key material (e.g.
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
//! zeroizing their own copies.
use aes::Aes256; use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use zeroize::Zeroize;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>; type Aes256Ctr = Ctr128BE<Aes256>;
@@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE<Aes256>;
/// AES-256-CTR encryptor/decryptor /// AES-256-CTR encryptor/decryptor
/// ///
/// CTR mode is symmetric - encryption and decryption are the same operation. /// CTR mode is symmetric encryption and decryption are the same operation.
///
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
/// + counter) is opaque and cannot be zeroized. If you need to protect key
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
/// before dropping them.
pub struct AesCtr { pub struct AesCtr {
cipher: Aes256Ctr, cipher: Aes256Ctr,
} }
@@ -62,14 +77,23 @@ impl AesCtr {
/// AES-256-CBC cipher with proper chaining /// AES-256-CBC cipher with proper chaining
/// ///
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption /// Unlike CTR mode, CBC is NOT symmetric encryption and decryption
/// are different operations. This implementation handles CBC chaining /// are different operations. This implementation handles CBC chaining
/// correctly across multiple blocks. /// correctly across multiple blocks.
///
/// Key and IV are zeroized on drop.
pub struct AesCbc { pub struct AesCbc {
key: [u8; 32], key: [u8; 32],
iv: [u8; 16], iv: [u8; 16],
} }
impl Drop for AesCbc {
fn drop(&mut self) {
self.key.zeroize();
self.iv.zeroize();
}
}
impl AesCbc { impl AesCbc {
/// AES block size /// AES block size
const BLOCK_SIZE: usize = 16; const BLOCK_SIZE: usize = 16;
@@ -141,17 +165,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let plaintext: [u8; 16] = chunk.try_into().unwrap(); let plaintext: [u8; 16] = chunk.try_into().unwrap();
// XOR plaintext with previous ciphertext (or IV for first block)
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
// Encrypt the XORed block
let ciphertext = self.encrypt_block(&xored, &key_schedule); let ciphertext = self.encrypt_block(&xored, &key_schedule);
// Save for next iteration
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&ciphertext); result.extend_from_slice(&ciphertext);
} }
@@ -180,17 +196,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap(); let ciphertext: [u8; 16] = chunk.try_into().unwrap();
// Decrypt the block
let decrypted = self.decrypt_block(&ciphertext, &key_schedule); let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
// XOR with previous ciphertext (or IV for first block)
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext); let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
// Save current ciphertext for next iteration
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&plaintext); result.extend_from_slice(&plaintext);
} }
@@ -217,16 +225,13 @@ impl AesCbc {
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
// Encrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule); *block_array = self.encrypt_block(block_array, &key_schedule);
// Save for next iteration
prev_ciphertext = *block_array; prev_ciphertext = *block_array;
} }
@@ -248,26 +253,20 @@ impl AesCbc {
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
// For in-place decryption, we need to save ciphertext blocks
// before we overwrite them
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
// Save current ciphertext before modifying
let current_ciphertext: [u8; 16] = block.try_into().unwrap(); let current_ciphertext: [u8; 16] = block.try_into().unwrap();
// Decrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule); *block_array = self.decrypt_block(block_array, &key_schedule);
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
// Save for next iteration
prev_ciphertext = current_ciphertext; prev_ciphertext = current_ciphertext;
} }
@@ -347,10 +346,8 @@ mod tests {
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
// Encrypted should be different
assert_ne!(&data[..], original); assert_ne!(&data[..], original);
// Decrypt with fresh cipher
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
@@ -364,7 +361,7 @@ mod tests {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let original = [0u8; 32]; // 2 blocks let original = [0u8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap(); let encrypted = cipher.encrypt(&original).unwrap();
@@ -375,31 +372,25 @@ mod tests {
#[test] #[test]
fn test_aes_cbc_chaining_works() { fn test_aes_cbc_chaining_works() {
// This is the key test - verify CBC chaining is correct
let key = [0x42u8; 32]; let key = [0x42u8; 32];
let iv = [0x00u8; 16]; let iv = [0x00u8; 16];
// Two IDENTICAL plaintext blocks
let plaintext = [0xAAu8; 32]; let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
let block1 = &ciphertext[0..16]; let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32]; let block2 = &ciphertext[16..32];
assert_ne!( assert_ne!(
block1, block2, block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \ "CBC chaining broken: identical plaintext blocks produced identical ciphertext"
This indicates ECB mode, not CBC!"
); );
} }
#[test] #[test]
fn test_aes_cbc_known_vector() { fn test_aes_cbc_known_vector() {
// Test with known NIST test vector
// AES-256-CBC with zero key and zero IV
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let plaintext = [0u8; 16]; let plaintext = [0u8; 16];
@@ -407,11 +398,9 @@ mod tests {
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt and verify roundtrip
let decrypted = cipher.decrypt(&ciphertext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice()); assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Ciphertext should not be all zeros
assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
} }
@@ -420,7 +409,6 @@ mod tests {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
// 5 blocks = 80 bytes
let plaintext: Vec<u8> = (0..80).collect(); let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -435,7 +423,7 @@ mod tests {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks let original = [0x56u8; 48];
let mut buffer = original; let mut buffer = original;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -462,41 +450,33 @@ mod tests {
fn test_aes_cbc_unaligned_error() { fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]); let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes - not aligned to block size
let result = cipher.encrypt(&[0u8; 15]); let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err()); assert!(result.is_err());
// 17 bytes - not aligned
let result = cipher.encrypt(&[0u8; 17]); let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn test_aes_cbc_avalanche_effect() { fn test_aes_cbc_avalanche_effect() {
// Changing one bit in plaintext should change entire ciphertext block
// and all subsequent blocks (due to chaining)
let key = [0xAB; 32]; let key = [0xAB; 32];
let iv = [0xCD; 16]; let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32]; let plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32]; let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Single bit difference in first block plaintext2[0] = 0x01;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First blocks should be different
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second blocks should ALSO be different (chaining effect)
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
} }
#[test] #[test]
fn test_aes_cbc_iv_matters() { fn test_aes_cbc_iv_matters() {
// Same plaintext with different IVs should produce different ciphertext
let key = [0x55; 32]; let key = [0x55; 32];
let plaintext = [0x77u8; 16]; let plaintext = [0x77u8; 16];
@@ -511,7 +491,6 @@ mod tests {
#[test] #[test]
fn test_aes_cbc_deterministic() { fn test_aes_cbc_deterministic() {
// Same key, IV, plaintext should always produce same ciphertext
let key = [0x99; 32]; let key = [0x99; 32];
let iv = [0x88; 16]; let iv = [0x88; 16];
let plaintext = [0x77u8; 32]; let plaintext = [0x77u8; 32];
@@ -524,6 +503,23 @@ mod tests {
assert_eq!(ciphertext1, ciphertext2); assert_eq!(ciphertext1, ciphertext2);
} }
// ============= Zeroize Tests =============
#[test]
fn test_aes_cbc_zeroize_on_drop() {
let key = [0xAA; 32];
let iv = [0xBB; 16];
let cipher = AesCbc::new(key, iv);
// Verify key/iv are set
assert_eq!(cipher.key, [0xAA; 32]);
assert_eq!(cipher.iv, [0xBB; 16]);
drop(cipher);
// After drop, key/iv are zeroized (can't observe directly,
// but the Drop impl runs without panic)
}
// ============= Error Handling Tests ============= // ============= Error Handling Tests =============
#[test] #[test]

View File

@@ -1,3 +1,16 @@
//! Cryptographic hash functions
//!
//! ## Protocol-required algorithms
//!
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
//! hash functions are **required by the Telegram Middle Proxy protocol**
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
//! compatibility. They are NOT used for any security-sensitive purpose
//! outside of that specific key derivation scheme mandated by Telegram.
//!
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
//! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256; use sha2::Sha256;
use md5::Md5; use md5::Md5;
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
mac.finalize().into_bytes().into() mac.finalize().into_bytes().into()
} }
/// SHA-1 /// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn sha1(data: &[u8]) -> [u8; 20] { pub fn sha1(data: &[u8]) -> [u8; 20] {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(data); hasher.update(data);
hasher.finalize().into() hasher.finalize().into()
} }
/// MD5 /// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn md5(data: &[u8]) -> [u8; 16] { pub fn md5(data: &[u8]) -> [u8; 16] {
let mut hasher = Md5::new(); let mut hasher = Md5::new();
hasher.update(data); hasher.update(data);
@@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// Middle Proxy Keygen /// Middle Proxy key derivation
///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
/// These algorithms are NOT replaceable here — changing them would break
/// interoperability with Telegram's middle proxy infrastructure.
pub fn derive_middleproxy_keys( pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16], nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16], nonce_clt: &[u8; 16],

View File

@@ -6,4 +6,4 @@ pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM}; pub use random::SecureRandom;

View File

@@ -3,11 +3,8 @@
use rand::{Rng, RngCore, SeedableRng}; use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use parking_lot::Mutex; use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR /// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom { pub struct SecureRandom {
@@ -20,18 +17,30 @@ struct SecureRandomInner {
buffer: Vec<u8>, buffer: Vec<u8>,
} }
impl Drop for SecureRandomInner {
fn drop(&mut self) {
self.buffer.zeroize();
}
}
impl SecureRandom { impl SecureRandom {
pub fn new() -> Self { pub fn new() -> Self {
let mut rng = StdRng::from_entropy(); let mut seed_source = rand::rng();
let mut rng = StdRng::from_rng(&mut seed_source);
let mut key = [0u8; 32]; let mut key = [0u8; 32];
rng.fill_bytes(&mut key); rng.fill_bytes(&mut key);
let iv: u128 = rng.gen(); let iv: u128 = rng.random();
let cipher = AesCtr::new(&key, iv);
// Zeroize local key copy — cipher already consumed it
key.zeroize();
Self { Self {
inner: Mutex::new(SecureRandomInner { inner: Mutex::new(SecureRandomInner {
rng, rng,
cipher: AesCtr::new(&key, iv), cipher,
buffer: Vec::with_capacity(1024), buffer: Vec::with_capacity(1024),
}), }),
} }
@@ -78,7 +87,6 @@ impl SecureRandom {
result |= (b as u64) << (i * 8); result |= (b as u64) << (i * 8);
} }
// Mask extra bits
if k < 64 { if k < 64 {
result &= (1u64 << k) - 1; result &= (1u64 << k) - 1;
} }
@@ -107,13 +115,13 @@ impl SecureRandom {
/// Generate random u32 /// Generate random u32
pub fn u32(&self) -> u32 { pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
/// Generate random u64 /// Generate random u64
pub fn u64(&self) -> u64 { pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
} }
@@ -162,12 +170,10 @@ mod tests {
fn test_bits() { fn test_bits() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
// Single bit should be 0 or 1
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(1) <= 1); assert!(rng.bits(1) <= 1);
} }
// 8 bits should be 0-255
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(8) <= 255); assert!(rng.bits(8) <= 255);
} }
@@ -185,10 +191,8 @@ mod tests {
} }
} }
// Should have seen all items
assert_eq!(seen.len(), 5); assert_eq!(seen.len(), 5);
// Empty slice should return None
let empty: Vec<i32> = vec![]; let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none()); assert!(rng.choose(&empty).is_none());
} }
@@ -201,12 +205,10 @@ mod tests {
let mut shuffled = original.clone(); let mut shuffled = original.clone();
rng.shuffle(&mut shuffled); rng.shuffle(&mut shuffled);
// Should contain same elements
let mut sorted = shuffled.clone(); let mut sorted = shuffled.clone();
sorted.sort(); sorted.sort();
assert_eq!(sorted, original); assert_eq!(sorted, original);
// Should be different order (with very high probability)
assert_ne!(shuffled, original); assert_ne!(shuffled, original);
} }
} }

View File

@@ -118,16 +118,13 @@ pub trait Recoverable {
impl Recoverable for StreamError { impl Recoverable for StreamError {
fn is_recoverable(&self) -> bool { fn is_recoverable(&self) -> bool {
match self { match self {
// Partial operations can be retried
Self::PartialRead { .. } | Self::PartialWrite { .. } => true, Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
// I/O errors depend on kind
Self::Io(e) => matches!( Self::Io(e) => matches!(
e.kind(), e.kind(),
std::io::ErrorKind::WouldBlock std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted | std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut | std::io::ErrorKind::TimedOut
), ),
// These are not recoverable
Self::Poisoned { .. } Self::Poisoned { .. }
| Self::BufferOverflow { .. } | Self::BufferOverflow { .. }
| Self::InvalidFrame { .. } | Self::InvalidFrame { .. }
@@ -137,13 +134,9 @@ impl Recoverable for StreamError {
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
match self { match self {
// Poisoned stream cannot be used
Self::Poisoned { .. } => false, Self::Poisoned { .. } => false,
// EOF means stream is done
Self::UnexpectedEof => false, Self::UnexpectedEof => false,
// Buffer overflow is fatal
Self::BufferOverflow { .. } => false, Self::BufferOverflow { .. } => false,
// Others might allow continuation
_ => true, _ => true,
} }
} }
@@ -383,18 +376,18 @@ mod tests {
#[test] #[test]
fn test_handshake_result() { fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
assert!(success.is_success()); assert!(success.is_success());
assert!(!success.is_bad_client()); assert!(!success.is_bad_client());
let bad: HandshakeResult<i32> = HandshakeResult::BadClient; let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
assert!(!bad.is_success()); assert!(!bad.is_success());
assert!(bad.is_bad_client()); assert!(bad.is_bad_client());
} }
#[test] #[test]
fn test_handshake_result_map() { fn test_handshake_result_map() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2); let mapped = success.map(|x| x * 2);
match mapped { match mapped {

View File

@@ -5,9 +5,10 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::signal; use tokio::signal;
use tracing::{info, error, warn}; use tracing::{info, error, warn, debug};
use tracing_subscriber::{fmt, EnvFilter}; use tracing_subscriber::{fmt, EnvFilter};
mod cli;
mod config; mod config;
mod crypto; mod crypto;
mod error; mod error;
@@ -18,78 +19,168 @@ mod stream;
mod transport; mod transport;
mod util; mod util;
use crate::config::ProxyConfig; use crate::config::{ProxyConfig, LogLevel};
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip; use crate::util::ip::detect_ip;
use crate::stream::BufferPool; use crate::stream::BufferPool;
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut silent = false;
let mut log_level: Option<String> = None;
let args: Vec<String> = std::env::args().skip(1).collect();
// Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
if let Err(e) = cli::run_init(init_opts) {
eprintln!("[telemt] Init failed: {}", e);
std::process::exit(1);
}
std::process::exit(0);
}
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => { silent = true; }
"--log-level" => {
i += 1;
if i < args.len() { log_level = Some(args[i].clone()); }
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start");
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
}
i += 1;
}
(config_path, silent, log_level)
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging let (config_path, cli_silent, cli_log_level) = parse_cli();
fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
.init();
// Load config
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
let config = match ProxyConfig::load(&config_path) { let config = match ProxyConfig::load(&config_path) {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
// If config doesn't exist, try to create default
if std::path::Path::new(&config_path).exists() { if std::path::Path::new(&config_path).exists() {
error!("Failed to load config: {}", e); eprintln!("[telemt] Error: {}", e);
std::process::exit(1); std::process::exit(1);
} else { } else {
let default = ProxyConfig::default(); let default = ProxyConfig::default();
let toml = toml::to_string_pretty(&default).unwrap(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
std::fs::write(&config_path, toml).unwrap(); eprintln!("[telemt] Created default config at {}", config_path);
info!("Created default config at {}", config_path);
default default
} }
} }
}; };
config.validate()?; if let Err(e) = config.validate() {
eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
// Log loaded configuration for debugging let effective_log_level = if cli_silent {
info!("=== Configuration Loaded ==="); LogLevel::Silent
info!("TLS Domain: {}", config.censorship.tls_domain); } else if let Some(ref s) = cli_log_level {
info!("Mask enabled: {}", config.censorship.mask); LogLevel::from_str_loose(s)
info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain)); } else {
info!("Mask port: {}", config.censorship.mask_port); config.general.log_level.clone()
info!("Modes: classic={}, secure={}, tls={}", };
let filter = if std::env::var("RUST_LOG").is_ok() {
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
fmt().with_env_filter(filter).init();
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
info!("Modes: classic={} secure={} tls={}",
config.general.modes.classic, config.general.modes.classic,
config.general.modes.secure, config.general.modes.secure,
config.general.modes.tls config.general.modes.tls);
); info!("TLS domain: {}", config.censorship.tls_domain);
info!("============================"); info!("Mask: {} -> {}:{}",
config.censorship.mask,
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
config.censorship.mask_port);
if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain. Consider setting a custom domain.");
}
let prefer_ipv6 = config.general.prefer_ipv6;
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
// Initialize global ReplayChecker let replay_checker = Arc::new(ReplayChecker::new(
// Using sharded implementation for better concurrency config.access.replay_check_len,
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len)); Duration::from_secs(config.access.replay_window_secs),
));
// Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
// Initialize Buffer Pool
// 16KB buffers, max 4096 buffers (~64MB total cached)
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start Health Checks // Startup DC ping
println!("=== Telegram DC Connectivity ===");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
println!(" via {}", upstream_result.upstream_name);
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) {
(Some(rtt), _) => {
println!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
}
(None, Some(err)) => {
println!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
}
_ => {
println!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
}
}
}
}
println!("================================");
// Background tasks
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
um_clone.run_health_checks().await;
}); let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
// Detect public IP if needed (once at startup)
let detected_ip = detect_ip().await; let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
// Start Listeners
let mut listeners = Vec::new(); let mut listeners = Vec::new();
for listener_conf in &config.server.listeners { for listener_conf in &config.server.listeners {
@@ -104,7 +195,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::from_std(socket.into())?; let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr); info!("Listening on {}", addr);
// Determine public IP for tg:// links
let public_ip = if let Some(ip) = listener_conf.announce_ip { let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip ip
} else if listener_conf.ip.is_unspecified() { } else if listener_conf.ip.is_unspecified() {
@@ -117,33 +207,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
listener_conf.ip listener_conf.ip
}; };
// Show links for configured users
if !config.show_link.is_empty() { if !config.show_link.is_empty() {
info!("--- Proxy Links for {} ---", public_ip); println!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link { for user_name in &config.show_link {
if let Some(secret) = config.access.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name); println!("[{}]", user_name);
if config.general.modes.classic { if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}", println!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret); public_ip, config.server.port, secret);
} }
if config.general.modes.secure { if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", println!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret); public_ip, config.server.port, secret);
} }
if config.general.modes.tls { if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain); let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", println!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex); public_ip, config.server.port, secret, domain_hex);
} }
} else { } else {
warn!("User '{}' specified in show_link not found in users list", user_name); warn!("User '{}' in show_link not found", user_name);
} }
} }
info!("-----------------------------------"); println!("------------------------");
} }
listeners.push(listener); listeners.push(listener);
@@ -155,17 +241,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
if listeners.is_empty() { if listeners.is_empty() {
error!("No listeners could be started. Exiting."); error!("No listeners. Exiting.");
std::process::exit(1); std::process::exit(1);
} }
// Accept loop
for listener in listeners { for listener in listeners {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -176,18 +262,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
stream, stream, peer_addr, config, stats,
peer_addr, upstream_manager, replay_checker, buffer_pool, rng
config,
stats,
upstream_manager,
replay_checker,
buffer_pool
).run().await { ).run().await {
// Log only relevant errors debug!(peer = %peer_addr, error = %e, "Connection error");
} }
}); });
} }
@@ -200,7 +282,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}); });
} }
// Wait for signal
match signal::ctrl_c().await { match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."), Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e), Err(e) => error!("Signal error: {}", e),

View File

@@ -1,13 +1,13 @@
//! Protocol constants and datacenter addresses //! Protocol constants and datacenter addresses
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy; use std::sync::LazyLock;
// ============= Telegram Datacenters ============= // ============= Telegram Datacenters =============
pub const TG_DATACENTER_PORT: u16 = 443; pub const TG_DATACENTER_PORT: u16 = 443;
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)), IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
] ]
}); });
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()), IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()), IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
// ============= Middle Proxies (for advertising) ============= // ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
m m
}); });
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
@@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes ============= // ============= Buffer Sizes =============
/// Default buffer size /// Default buffer size
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
/// the new buffering strategy for better iOS upload performance.
pub const DEFAULT_BUFFER_SIZE: usize = 16384; pub const DEFAULT_BUFFER_SIZE: usize = 16384;
/// Small buffer size for bad client handling /// Small buffer size for bad client handling

View File

@@ -1,10 +1,13 @@
//! MTProto Obfuscation //! MTProto Obfuscation
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr};
use crate::error::Result; use crate::error::Result;
use super::constants::*; use super::constants::*;
/// Obfuscation parameters from handshake /// Obfuscation parameters from handshake
///
/// Key material is zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ObfuscationParams { pub struct ObfuscationParams {
/// Key for decrypting client -> proxy traffic /// Key for decrypting client -> proxy traffic
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
pub dc_idx: i16, pub dc_idx: i16,
} }
impl Drop for ObfuscationParams {
fn drop(&mut self) {
self.decrypt_key.zeroize();
self.decrypt_iv.zeroize();
self.encrypt_key.zeroize();
self.encrypt_iv.zeroize();
}
}
impl ObfuscationParams { impl ObfuscationParams {
/// Parse obfuscation parameters from handshake bytes /// Parse obfuscation parameters from handshake bytes
/// Returns None if handshake doesn't match any user secret /// Returns None if handshake doesn't match any user secret
pub fn from_handshake( pub fn from_handshake(
handshake: &[u8; HANDSHAKE_LEN], handshake: &[u8; HANDSHAKE_LEN],
secrets: &[(String, Vec<u8>)], // (username, secret_bytes) secrets: &[(String, Vec<u8>)],
) -> Option<(Self, String)> { ) -> Option<(Self, String)> {
// Extract prekey and IV for decryption
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets { for (username, secret) in secrets {
// Derive decryption key
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret); dec_key_input.extend_from_slice(secret);
@@ -47,26 +56,22 @@ impl ObfuscationParams {
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Create decryptor and decrypt handshake
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => continue, // Try next secret None => continue,
}; };
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret); enc_key_input.extend_from_slice(secret);
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
/// Check if nonce is valid (not matching reserved patterns) /// Check if nonce is valid (not matching reserved patterns)
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
// Check first byte
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false; return false;
} }
// Check first 4 bytes
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false; return false;
} }
// Check bytes 4-7
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false; return false;
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
pub fn prepare_tg_nonce( pub fn prepare_tg_nonce(
nonce: &mut [u8; HANDSHAKE_LEN], nonce: &mut [u8; HANDSHAKE_LEN],
proto_tag: ProtoTag, proto_tag: ProtoTag,
enc_key_iv: Option<&[u8]>, // For fast mode enc_key_iv: Option<&[u8]>,
) { ) {
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// For fast mode, copy the reversed enc_key_iv
if let Some(key_iv) = enc_key_iv { if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect(); let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
@@ -161,14 +161,12 @@ pub fn prepare_tg_nonce(
/// Encrypt the outgoing nonce for Telegram /// Encrypt the outgoing nonce for Telegram
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// Derive encryption key from the nonce itself
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv); let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv); let mut encryptor = AesCtr::new(&enc_key, enc_iv);
// Only encrypt from PROTO_TAG_POS onwards
let mut result = nonce.to_vec(); let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
@@ -182,22 +180,18 @@ mod tests {
#[test] #[test]
fn test_is_valid_nonce() { fn test_is_valid_nonce() {
// Valid nonce
let mut valid = [0x42u8; HANDSHAKE_LEN]; let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]); valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid)); assert!(is_valid_nonce(&valid));
// Invalid: starts with 0xef
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef; invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: starts with HEAD
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD"); invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: bytes 4-7 are zeros
let mut invalid = [0x42u8; HANDSHAKE_LEN]; let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));

View File

@@ -4,7 +4,7 @@
//! for domain fronting. The handshake looks like valid TLS 1.3 but //! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data. //! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use super::constants::*; use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
@@ -315,8 +315,8 @@ pub fn validate_tls_handshake(
/// ///
/// This generates random bytes that look like a valid X25519 public key. /// This generates random bytes that look like a valid X25519 public key.
/// Since we're not doing real TLS, the actual cryptographic properties don't matter. /// Since we're not doing real TLS, the actual cryptographic properties don't matter.
pub fn gen_fake_x25519_key() -> [u8; 32] { pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let bytes = SECURE_RANDOM.bytes(32); let bytes = rng.bytes(32);
bytes.try_into().unwrap() bytes.try_into().unwrap()
} }
@@ -333,8 +333,9 @@ pub fn build_server_hello(
client_digest: &[u8; TLS_DIGEST_LEN], client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8], session_id: &[u8],
fake_cert_len: usize, fake_cert_len: usize,
rng: &SecureRandom,
) -> Vec<u8> { ) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key(); let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello // Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec()) let server_hello = ServerHelloBuilder::new(session_id.to_vec())
@@ -351,7 +352,7 @@ pub fn build_server_hello(
]; ];
// Build fake certificate (Application Data record) // Build fake certificate (Application Data record)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); let fake_cert = rng.bytes(fake_cert_len);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION); app_data_record.extend_from_slice(&TLS_VERSION);
@@ -489,8 +490,9 @@ mod tests {
#[test] #[test]
fn test_gen_fake_x25519_key() { fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key(); let rng = SecureRandom::new();
let key2 = gen_fake_x25519_key(); let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32); assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32); assert_eq!(key2.len(), 32);
@@ -545,7 +547,8 @@ mod tests {
let client_digest = [0x42u8; 32]; let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32]; let session_id = vec![0xAA; 32];
let response = build_server_hello(secret, &client_digest, &session_id, 2048); let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
// Should have at least 3 records // Should have at least 3 records
assert!(response.len() > 100); assert!(response.len() > 100);
@@ -577,8 +580,9 @@ mod tests {
let client_digest = [0x42u8; 32]; let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32]; let session_id = vec![0xAA; 32];
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024); let rng = SecureRandom::new();
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024); let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
// Digest position should have non-zero data // Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];

View File

@@ -15,9 +15,8 @@ use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr; use crate::crypto::{AesCtr, SecureRandom};
// Use absolute paths to avoid confusion
use crate::proxy::handshake::{ use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake, handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
@@ -25,10 +24,8 @@ use crate::proxy::handshake::{
use crate::proxy::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
/// Client connection handler (builder struct)
pub struct ClientHandler; pub struct ClientHandler;
/// Running client handler with stream and context
pub struct RunningClientHandler { pub struct RunningClientHandler {
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
@@ -37,10 +34,10 @@ pub struct RunningClientHandler {
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
} }
impl ClientHandler { impl ClientHandler {
/// Create new client handler instance
pub fn new( pub fn new(
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
@@ -49,28 +46,22 @@ impl ClientHandler {
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> RunningClientHandler { ) -> RunningClientHandler {
RunningClientHandler { RunningClientHandler {
stream, stream, peer, config, stats, replay_checker,
peer, upstream_manager, buffer_pool, rng,
config,
stats,
replay_checker,
upstream_manager,
buffer_pool,
} }
} }
} }
impl RunningClientHandler { impl RunningClientHandler {
/// Run the client handler
pub async fn run(mut self) -> Result<()> { pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all(); self.stats.increment_connects_all();
let peer = self.peer; let peer = self.peer;
debug!(peer = %peer, "New connection"); debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&self.stream, &self.stream,
self.config.timeouts.client_keepalive, self.config.timeouts.client_keepalive,
@@ -79,16 +70,10 @@ impl RunningClientHandler {
debug!(peer = %peer, error = %e, "Failed to configure client socket"); debug!(peer = %peer, error = %e, "Failed to configure client socket");
} }
// Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
// Clone stats for error handling block
let stats = self.stats.clone(); let stats = self.stats.clone();
let result = timeout( let result = timeout(handshake_timeout, self.do_handshake()).await;
handshake_timeout,
self.do_handshake()
).await;
match result { match result {
Ok(Ok(())) => { Ok(Ok(())) => {
@@ -107,16 +92,14 @@ impl RunningClientHandler {
} }
} }
/// Perform handshake and relay
async fn do_handshake(mut self) -> Result<()> { async fn do_handshake(mut self) -> Result<()> {
// Read first bytes to determine handshake type
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer; let peer = self.peer;
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls { if is_tls {
self.handle_tls_client(first_bytes).await self.handle_tls_client(first_bytes).await
@@ -125,14 +108,9 @@ impl RunningClientHandler {
} }
} }
/// Handle TLS-wrapped client async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_tls_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer; let peer = self.peer;
// Read TLS handshake length
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
@@ -140,34 +118,25 @@ impl RunningClientHandler {
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields before consuming self.stream
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, &self.rng,
write_half,
peer,
&config,
&replay_checker,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
@@ -178,83 +147,56 @@ impl RunningClientHandler {
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Read MTProto handshake through TLS
debug!(peer = %peer, "Reading MTProto handshake through TLS"); debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, &mtproto_handshake, tls_reader, tls_writer, peer,
tls_reader, &config, &replay_checker, true,
tls_writer,
peer,
&config,
&replay_checker,
true,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader: _, writer: _ } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
// Valid TLS but invalid MTProto - drop debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader, crypto_writer, success,
crypto_writer, self.upstream_manager, self.stats, self.config,
success, buffer_pool, self.rng,
self.upstream_manager,
self.stats,
self.config,
buffer_pool
).await ).await
} }
/// Handle direct (non-TLS) client async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_direct_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer; let peer = self.peer;
// Check if non-TLS modes are enabled
if !self.config.general.modes.classic && !self.config.general.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split(); let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await; handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
// Split stream
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, false,
write_half,
peer,
&config,
&replay_checker,
false,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
@@ -266,17 +208,12 @@ impl RunningClientHandler {
}; };
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader, crypto_writer, success,
crypto_writer, self.upstream_manager, self.stats, self.config,
success, buffer_pool, self.rng,
self.upstream_manager,
self.stats,
self.config,
buffer_pool
).await ).await
} }
/// Static version of handle_authenticated_inner
async fn handle_authenticated_static<R, W>( async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@@ -285,6 +222,7 @@ impl RunningClientHandler {
stats: Arc<Stats>, stats: Arc<Stats>,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -292,13 +230,11 @@ impl RunningClientHandler {
{ {
let user = &success.user; let user = &success.user;
// Check user limits
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
warn!(user = %user, error = %e, "User limit exceeded"); warn!(user = %user, error = %e, "User limit exceeded");
return Err(e); return Err(e);
} }
// Get datacenter address
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
info!( info!(
@@ -307,71 +243,54 @@ impl RunningClientHandler {
dc = success.dc_idx, dc = success.dc_idx,
dc_addr = %dc_addr, dc_addr = %dc_addr,
proto = ?success.proto_tag, proto = ?success.proto_tag,
fast_mode = config.general.fast_mode,
"Connecting to Telegram" "Connecting to Telegram"
); );
// Connect to Telegram via UpstreamManager // Pass dc_idx for latency-based upstream selection
let tg_stream = upstream_manager.connect(dc_addr).await?; let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
// Perform Telegram handshake and get crypto streams
let (tg_reader, tg_writer) = Self::do_tg_handshake_static( let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
tg_stream, tg_stream, &success, &config, rng.as_ref(),
&success,
&config,
).await?; ).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); debug!(peer = %success.peer, "TG handshake complete, starting relay");
// Update stats
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user); stats.increment_user_curr_connects(user);
// Relay traffic using buffer pool
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
client_reader, client_reader, client_writer,
client_writer, tg_reader, tg_writer,
tg_reader, user, Arc::clone(&stats), buffer_pool,
tg_writer,
user,
Arc::clone(&stats),
buffer_pool,
).await; ).await;
// Update stats
stats.decrement_user_curr_connects(user); stats.decrement_user_curr_connects(user);
match &relay_result { match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), Ok(()) => debug!(user = %user, "Relay completed"),
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"), Err(e) => debug!(user = %user, error = %e, "Relay ended with error"),
} }
relay_result relay_result
} }
/// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration
if let Some(expiration) = config.access.user_expirations.get(user) { if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired { user: user.to_string() });
} }
} }
// Check connection limit
if let Some(limit) = config.access.user_max_tcp_conns.get(user) { if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user); if stats.get_user_curr_connects(user) >= *limit as u64 {
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
} }
} }
// Check data quota
if let Some(quota) = config.access.user_data_quota.get(user) { if let Some(quota) = config.access.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user); if stats.get_user_total_octets(user) >= *quota {
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
} }
} }
@@ -379,7 +298,6 @@ impl RunningClientHandler {
Ok(()) Ok(())
} }
/// Get datacenter address by index (static version)
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize; let idx = (dc_idx.abs() - 1) as usize;
@@ -396,45 +314,39 @@ impl RunningClientHandler {
)) ))
} }
/// Perform handshake with Telegram server (static version)
async fn do_tg_handshake_static( async fn do_tg_handshake_static(
mut stream: TcpStream, mut stream: TcpStream,
success: &HandshakeSuccess, success: &HandshakeSuccess,
config: &ProxyConfig, config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> { ) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag, success.proto_tag,
&success.dec_key, // Client's dec key &success.dec_key,
success.dec_iv, success.dec_iv,
rng,
config.general.fast_mode, config.general.fast_mode,
); );
// Encrypt nonce
let encrypted_nonce = encrypt_tg_nonce(&nonce); let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!( debug!(
peer = %success.peer, peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]), nonce_head = %hex::encode(&nonce[..16]),
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
"Sending nonce to Telegram" "Sending nonce to Telegram"
); );
// Send to Telegram
stream.write_all(&encrypted_nonce).await?; stream.write_all(&encrypted_nonce).await?;
stream.flush().await?; stream.flush().await?;
debug!(peer = %success.peer, "Nonce sent to Telegram");
// Split stream and wrap with crypto
let (read_half, write_half) = stream.into_split(); let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
let tg_reader = CryptoReader::new(read_half, decryptor); Ok((
let tg_writer = CryptoWriter::new(write_half, encryptor); CryptoReader::new(read_half, decryptor),
CryptoWriter::new(write_half, encryptor),
Ok((tg_reader, tg_writer)) ))
} }
} }

View File

@@ -1,11 +1,11 @@
//! MTProto Handshake Magics //! MTProto Handshake
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info}; use tracing::{debug, warn, trace, info};
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::crypto::random::SECURE_RANDOM;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
/// Result of successful handshake /// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HandshakeSuccess { pub struct HandshakeSuccess {
/// Authenticated user name /// Authenticated user name
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
pub is_tls: bool, pub is_tls: bool,
} }
impl Drop for HandshakeSuccess {
fn drop(&mut self) {
self.dec_key.zeroize();
self.dec_iv.zeroize();
self.enc_key.zeroize();
self.enc_iv.zeroize();
}
}
/// Handle fake TLS handshake /// Handle fake TLS handshake
pub async fn handle_tls_handshake<R, W>( pub async fn handle_tls_handshake<R, W>(
handshake: &[u8], handshake: &[u8],
@@ -42,6 +54,7 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
rng: &SecureRandom,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W> ) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
@@ -49,30 +62,25 @@ where
{ {
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
// Extract digest for replay check
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
// Check for replay
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
// Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter() let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| { .filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
}) })
.collect(); .collect();
// Validate handshake
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
&secrets, &secrets,
@@ -89,18 +97,17 @@ where
} }
}; };
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s, Some((_, s)) => s,
None => return HandshakeResult::BadClient { reader, writer }, None => return HandshakeResult::BadClient { reader, writer },
}; };
// Build and send response
let response = tls::build_server_hello( let response = tls::build_server_hello(
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
config.censorship.fake_cert_len, config.censorship.fake_cert_len,
rng,
); );
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -115,7 +122,6 @@ where
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
// Record for replay protection only after successful handshake
replay_checker.add_tls_digest(digest_half); replay_checker.add_tls_digest(digest_half);
info!( info!(
@@ -147,26 +153,21 @@ where
{ {
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
// Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret
for (user, secret_hex) in &config.access.users { for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) { let secret = match hex::decode(secret_hex) {
Ok(s) => s, Ok(s) => s,
Err(_) => continue, Err(_) => continue,
}; };
// Derive decryption key
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@@ -177,11 +178,9 @@ where
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Decrypt handshake to check protocol tag
let mut decryptor = AesCtr::new(&dec_key, dec_iv); let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
@@ -191,7 +190,6 @@ where
None => continue, None => continue,
}; };
// Check if mode is enabled
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure } if is_tls { config.general.modes.tls } else { config.general.modes.secure }
@@ -204,12 +202,10 @@ where
continue; continue;
} }
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
@@ -220,10 +216,8 @@ where
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
// Record for replay protection
replay_checker.add_handshake(dec_prekey_iv); replay_checker.add_handshake(dec_prekey_iv);
// Create new cipher instances
let decryptor = AesCtr::new(&dec_key, dec_iv); let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
@@ -264,10 +258,11 @@ pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
client_dec_key: &[u8; 32], client_dec_key: &[u8; 32],
client_dec_iv: u128, client_dec_iv: u128,
rng: &SecureRandom,
fast_mode: bool, fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop { loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; } if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
@@ -323,13 +318,12 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = let rng = SecureRandom::new();
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
// Check length
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
// Check proto tag is set
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
} }
@@ -339,17 +333,35 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce); let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN); assert_eq!(encrypted.len(), HANDSHAKE_LEN);
// First PROTO_TAG_POS bytes should be unchanged
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
// Rest should be different (encrypted)
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
} }
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
} }

View File

@@ -1,31 +1,28 @@
//! Statistics //! Statistics and replay protection
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::{Instant, Duration};
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::{RwLock, Mutex}; use parking_lot::Mutex;
use lru::LruCache; use lru::LruCache;
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use tracing::debug;
// ============= Stats =============
/// Thread-safe statistics
#[derive(Default)] #[derive(Default)]
pub struct Stats { pub struct Stats {
// Global counters
connects_all: AtomicU64, connects_all: AtomicU64,
connects_bad: AtomicU64, connects_bad: AtomicU64,
handshake_timeouts: AtomicU64, handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>,
// Start time
start_time: RwLock<Option<Instant>>,
} }
/// Per-user statistics
#[derive(Default)] #[derive(Default)]
pub struct UserStats { pub struct UserStats {
pub connects: AtomicU64, pub connects: AtomicU64,
@@ -43,42 +40,20 @@ impl Stats {
stats stats
} }
// Global stats pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
pub fn increment_connects_all(&self) { pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
self.connects_all.fetch_add(1, Ordering::Relaxed); pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
} pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_curr_connects(&self, user: &str) { pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .curr_connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -88,47 +63,33 @@ impl Stats {
} }
pub fn get_user_curr_connects(&self, user: &str) -> u64 { pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed)) .map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0) .unwrap_or(0)
} }
pub fn add_user_octets_from(&self, user: &str, bytes: u64) { pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_from_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn add_user_octets_to(&self, user: &str, bytes: u64) { pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_to_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn increment_user_msgs_from(&self, user: &str) { pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_from_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_msgs_to(&self, user: &str) { pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_to_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn get_user_total_octets(&self, user: &str) -> u64 { pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| { .map(|s| {
s.octets_from_client.load(Ordering::Relaxed) + s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed) s.octets_to_client.load(Ordering::Relaxed)
@@ -143,57 +104,209 @@ impl Stats {
} }
} }
/// Sharded Replay attack checker using LRU cache // ============= Replay Checker =============
/// Uses multiple independent LRU caches to reduce lock contention
pub struct ReplayChecker { pub struct ReplayChecker {
shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>, shards: Vec<Mutex<ReplayShard>>,
shard_mask: usize, shard_mask: usize,
window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
}
struct ReplayEntry {
seen_at: Instant,
seq: u64,
}
struct ReplayShard {
cache: LruCache<Box<[u8]>, ReplayEntry>,
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
seq_counter: u64,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
}
}
fn next_seq(&mut self) -> u64 {
self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
// between Borrow<[u8]> and Borrow<Box<[u8]>>
if let Some(entry) = self.cache.peek(key.as_ref()) {
if entry.seq == queue_seq {
self.cache.pop(key.as_ref());
}
}
}
}
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
self.cleanup(now, window);
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
self.cache.get(key).is_some()
}
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
self.cleanup(now, window);
let seq = self.next_seq();
let boxed_key: Box<[u8]> = key.into();
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
self.queue.push_back((now, boxed_key, seq));
}
fn len(&self) -> usize {
self.cache.len()
}
} }
impl ReplayChecker { impl ReplayChecker {
/// Create new replay checker with specified capacity per shard pub fn new(total_capacity: usize, window: Duration) -> Self {
/// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64; let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1); let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap(); let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards); let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards { for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap))); shards.push(Mutex::new(ReplayShard::new(cap)));
} }
Self { Self {
shards, shards,
shard_mask: num_shards - 1, shard_mask: num_shards - 1,
window,
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
} }
} }
fn get_shard(&self, key: &[u8]) -> usize { fn get_shard_idx(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
key.hash(&mut hasher); key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask (hasher.finish() as usize) & self.shard_mask
} }
pub fn check_handshake(&self, data: &[u8]) -> bool { fn check(&self, data: &[u8]) -> bool {
let shard_idx = self.get_shard(data); self.checks.fetch_add(1, Ordering::Relaxed);
self.shards[shard_idx].lock().contains(&data.to_vec()) let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let found = shard.check(data, Instant::now(), self.window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
} }
pub fn add_handshake(&self, data: &[u8]) { fn add(&self, data: &[u8]) {
let shard_idx = self.get_shard(data); self.additions.fetch_add(1, Ordering::Relaxed);
self.shards[shard_idx].lock().put(data.to_vec(), ()); let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
} }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
let shard_idx = self.get_shard(data); pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
self.shards[shard_idx].lock().contains(&data.to_vec()) pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
} }
pub fn add_tls_digest(&self, data: &[u8]) { ReplayStats {
let shard_idx = self.get_shard(data); total_entries,
self.shards[shard_idx].lock().put(data.to_vec(), ()); total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
window_secs: self.window.as_secs(),
}
}
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs(self.window.as_secs().max(1) / 2)
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 { 0.0 }
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
}
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 { 0.0 }
else { self.total_queue_len as f64 / self.total_entries as f64 }
} }
} }
@@ -204,28 +317,60 @@ mod tests {
#[test] #[test]
fn test_stats_shared_counters() { fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
stats.increment_connects_all();
let stats1 = Arc::clone(&stats); stats.increment_connects_all();
let stats2 = Arc::clone(&stats); stats.increment_connects_all();
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
assert_eq!(stats.get_connects_all(), 3); assert_eq!(stats.get_connects_all(), 3);
} }
#[test] #[test]
fn test_replay_checker_sharding() { fn test_replay_checker_basic() {
let checker = ReplayChecker::new(100); let checker = ReplayChecker::new(100, Duration::from_secs(60));
let data1 = b"test1"; assert!(!checker.check_handshake(b"test1"));
let data2 = b"test2"; checker.add_handshake(b"test1");
assert!(checker.check_handshake(b"test1"));
assert!(!checker.check_handshake(b"test2"));
}
checker.add_handshake(data1); #[test]
assert!(checker.check_handshake(data1)); fn test_replay_checker_duplicate_add() {
assert!(!checker.check_handshake(data2)); let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"dup");
checker.add_handshake(b"dup");
assert!(checker.check_handshake(b"dup"));
}
checker.add_handshake(data2); #[test]
assert!(checker.check_handshake(data2)); fn test_replay_checker_expiration() {
let checker = ReplayChecker::new(100, Duration::from_millis(50));
checker.add_handshake(b"expire");
assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire"));
}
#[test]
fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"k1");
checker.add_handshake(b"k2");
checker.check_handshake(b"k1");
checker.check_handshake(b"k3");
let stats = checker.stats();
assert_eq!(stats.total_additions, 2);
assert_eq!(stats.total_checks, 2);
assert_eq!(stats.total_hits, 1);
}
#[test]
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check(&i.to_le_bytes()));
}
assert_eq!(checker.stats().total_entries, 500);
} }
} }

View File

@@ -5,8 +5,10 @@
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::io::Result; use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
// ============= Frame Types ============= // ============= Frame Types =============
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
// ============= Codec Factory ============= // ============= Codec Factory =============
/// Create a frame codec for the given protocol tag /// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> { pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
match proto_tag { match proto_tag {
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()), ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()), ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()), ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
} }
} }

View File

@@ -5,9 +5,11 @@
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind}; use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec ============= // ============= Unified Codec =============
@@ -21,14 +23,17 @@ pub struct FrameCodec {
proto_tag: ProtoTag, proto_tag: ProtoTag,
/// Maximum allowed frame size /// Maximum allowed frame size
max_frame_size: usize, max_frame_size: usize,
/// RNG for secure padding
rng: Arc<SecureRandom>,
} }
impl FrameCodec { impl FrameCodec {
/// Create a new codec for the given protocol /// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag) -> Self { pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
Self { Self {
proto_tag, proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default max_frame_size: 16 * 1024 * 1024, // 16MB default
rng,
} }
} }
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
match self.proto_tag { match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst), ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst), ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst), ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
} }
} }
} }
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
Ok(Some(Frame::with_meta(data, meta))) Ok(Some(Frame::with_meta(data, meta)))
} }
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
use crate::crypto::random::SECURE_RANDOM;
let data = &frame.data; let data = &frame.data;
// Simple ACK: just send data // Simple ACK: just send data
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
// Generate padding to make length not divisible by 4 // Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 { let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned // Add 1-3 bytes to make it non-aligned
(SECURE_RANDOM.range(3) + 1) as usize (rng.range(3) + 1) as usize
} else { } else {
// Already non-aligned, can add 0-3 // Already non-aligned, can add 0-3
SECURE_RANDOM.range(4) as usize rng.range(4) as usize
}; };
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(data); dst.extend_from_slice(data);
if padding_len > 0 { if padding_len > 0 {
let padding = SECURE_RANDOM.bytes(padding_len); let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding); dst.extend_from_slice(&padding);
} }
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
/// Secure Intermediate protocol codec /// Secure Intermediate protocol codec
pub struct SecureCodec { pub struct SecureCodec {
max_frame_size: usize, max_frame_size: usize,
rng: Arc<SecureRandom>,
} }
impl SecureCodec { impl SecureCodec {
pub fn new() -> Self { pub fn new(rng: Arc<SecureRandom>) -> Self {
Self { Self {
max_frame_size: 16 * 1024 * 1024, max_frame_size: 16 * 1024 * 1024,
rng,
} }
} }
} }
impl Default for SecureCodec { impl Default for SecureCodec {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new(Arc::new(SecureRandom::new()))
} }
} }
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
type Error = io::Error; type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst) encode_secure(&frame, dst, &self.rng)
} }
} }
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> { fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len(); let before = dst.len();
encode_secure(frame, dst)?; encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before) Ok(dst.len() - before)
} }
@@ -506,6 +511,8 @@ mod tests {
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex; use tokio::io::duplex;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use std::sync::Arc;
#[tokio::test] #[tokio::test]
async fn test_framed_abridged() { async fn test_framed_abridged() {
@@ -541,8 +548,8 @@ mod tests {
async fn test_framed_secure() { async fn test_framed_secure() {
let (client, server) = duplex(4096); let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new()); let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new()); let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone()); let frame = Frame::new(original.clone());
@@ -557,8 +564,8 @@ mod tests {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096); let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag)); let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag)); let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
// Use 4-byte aligned data for abridged compatibility // Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
@@ -607,7 +614,7 @@ mod tests {
#[test] #[test]
fn test_frame_too_large() { fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate) let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100); .with_max_frame_size(100);
// Create a "frame" that claims to be very large // Create a "frame" that claims to be very large

View File

@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result}; use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::crypto::crc32; use crate::crypto::{crc32, SecureRandom};
use crate::crypto::random::SECURE_RANDOM; use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream}; use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame ============= // ============= Abridged (Compact) Frame =============
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
/// Writer for secure intermediate MTProto framing /// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> { pub struct SecureIntermediateFrameWriter<W> {
upstream: W, upstream: W,
rng: Arc<SecureRandom>,
} }
impl<W> SecureIntermediateFrameWriter<W> { impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self { pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
Self { upstream } Self { upstream, rng }
} }
} }
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
} }
// Add random padding (0-3 bytes) // Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4); let padding_len = self.rng.range(4);
let padding = SECURE_RANDOM.bytes(padding_len); let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes(); let len_bytes = (total_len as u32).to_le_bytes();
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
} }
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> { impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag { match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
} }
} }
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
mod tests { mod tests {
use super::*; use super::*;
use tokio::io::duplex; use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
#[tokio::test] #[tokio::test]
async fn test_abridged_roundtrip() { async fn test_abridged_roundtrip() {
@@ -539,7 +542,7 @@ mod tests {
async fn test_secure_intermediate_padding() { async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client); let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server); let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
@@ -572,7 +575,7 @@ mod tests {
async fn test_frame_reader_kind() { async fn test_frame_reader_kind() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4]; let data = vec![1u8, 2, 3, 4];

View File

@@ -10,4 +10,4 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*; pub use socket::*;
pub use socks::*; pub use socks::*;
pub use upstream::UpstreamManager; pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};

View File

@@ -1,26 +1,113 @@
//! Upstream Management //! Upstream Management with per-DC latency-weighted selection
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio::time::Instant;
use rand::Rng; use rand::Rng;
use tracing::{debug, warn, error, info}; use tracing::{debug, warn, info, trace};
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError}; use crate::error::{Result, ProxyError};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::create_outgoing_socket_bound; use crate::transport::socket::create_outgoing_socket_bound;
use crate::transport::socks::{connect_socks4, connect_socks5}; use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters
const NUM_DCS: usize = 5;
// ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)]
struct LatencyEma {
value_ms: Option<f64>,
alpha: f64,
}
impl LatencyEma {
const fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
// ============= Upstream State =============
#[derive(Debug)] #[derive(Debug)]
struct UpstreamState { struct UpstreamState {
config: UpstreamConfig, config: UpstreamConfig,
healthy: bool, healthy: bool,
fails: u32, fails: u32,
last_check: std::time::Instant, last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS],
} }
impl UpstreamState {
fn new(config: UpstreamConfig) -> Self {
Self {
config,
healthy: true,
fails: 0,
last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
}
}
/// Convert dc_idx (1-based, may be negative) to array index 0..4
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let idx = (dc_idx.unsigned_abs() as usize).checked_sub(1)?;
if idx < NUM_DCS { Some(idx) } else { None }
}
/// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
// Try DC-specific latency first
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() {
return Some(ms);
}
}
// Fallback: average of all known DC latencies
let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
if count > 0 { Some(sum / count as f64) } else { None }
}
}
/// Result of a single DC ping
#[derive(Debug, Clone)]
pub struct DcPingResult {
pub dc_idx: usize,
pub dc_addr: SocketAddr,
pub rtt_ms: Option<f64>,
pub error: Option<String>,
}
/// Result of startup ping for one upstream
#[derive(Debug, Clone)]
pub struct StartupPingResult {
pub results: Vec<DcPingResult>,
pub upstream_name: String,
}
// ============= Upstream Manager =============
#[derive(Clone)] #[derive(Clone)]
pub struct UpstreamManager { pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>, upstreams: Arc<RwLock<Vec<UpstreamState>>>,
@@ -30,12 +117,7 @@ impl UpstreamManager {
pub fn new(configs: Vec<UpstreamConfig>) -> Self { pub fn new(configs: Vec<UpstreamConfig>) -> Self {
let states = configs.into_iter() let states = configs.into_iter()
.filter(|c| c.enabled) .filter(|c| c.enabled)
.map(|c| UpstreamState { .map(UpstreamState::new)
config: c,
healthy: true, // Optimistic start
fails: 0,
last_check: std::time::Instant::now(),
})
.collect(); .collect();
Self { Self {
@@ -43,48 +125,78 @@ impl UpstreamManager {
} }
} }
/// Select an upstream using Weighted Round Robin (simplified) /// Select upstream using latency-weighted random selection.
async fn select_upstream(&self) -> Option<usize> { ///
/// `effective_weight = config_weight × latency_factor`
///
/// where `latency_factor = 1000 / latency_ms` if latency is known,
/// or `1.0` if no latency data is available.
///
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
/// factor 5 — the faster route is 4× more likely to be chosen
/// (all else being equal).
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await; let upstreams = self.upstreams.read().await;
if upstreams.is_empty() { if upstreams.is_empty() {
return None; return None;
} }
let healthy_indices: Vec<usize> = upstreams.iter() let healthy: Vec<usize> = upstreams.iter()
.enumerate() .enumerate()
.filter(|(_, u)| u.healthy) .filter(|(_, u)| u.healthy)
.map(|(i, _)| i) .map(|(i, _)| i)
.collect(); .collect();
if healthy_indices.is_empty() { if healthy.is_empty() {
// If all unhealthy, try any random one // All unhealthy — pick any
return Some(rand::thread_rng().gen_range(0..upstreams.len())); return Some(rand::rng().gen_range(0..upstreams.len()));
} }
// Weighted selection if healthy.len() == 1 {
let total_weight: u32 = healthy_indices.iter() return Some(healthy[0]);
.map(|&i| upstreams[i].config.weight as u32)
.sum();
if total_weight == 0 {
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
} }
let mut choice = rand::thread_rng().gen_range(0..total_weight); // Calculate latency-weighted scores
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx)
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
.unwrap_or(1.0);
for &idx in &healthy_indices { (i, base * latency_factor)
let weight = upstreams[idx].config.weight as u32; }).collect();
let total: f64 = weights.iter().map(|(_, w)| w).sum();
if total <= 0.0 {
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
}
let mut choice: f64 = rand::rng().gen_range(0.0..total);
for &(idx, weight) in &weights {
if choice < weight { if choice < weight {
trace!(
upstream = idx,
dc = ?dc_idx,
weight = format!("{:.2}", weight),
total = format!("{:.2}", total),
"Upstream selected"
);
return Some(idx); return Some(idx);
} }
choice -= weight; choice -= weight;
} }
Some(healthy_indices[0]) Some(healthy[0])
} }
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> { /// Connect to target through a selected upstream.
let idx = self.select_upstream().await ///
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
/// Pass `None` if DC index is unknown.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
let upstream = { let upstream = {
@@ -92,28 +204,34 @@ impl UpstreamManager {
guard[idx].config.clone() guard[idx].config.clone()
}; };
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await { match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => { Ok(stream) => {
// Mark success let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) { if let Some(u) = guard.get_mut(idx) {
if !u.healthy { if !u.healthy {
debug!("Upstream recovered: {:?}", u.config); debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
} }
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
// Store per-DC latency
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms);
}
} }
Ok(stream) Ok(stream)
}, },
Err(e) => { Err(e) => {
// Mark failure
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) { if let Some(u) = guard.get_mut(idx) {
u.fails += 1; u.fails += 1;
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails); warn!(fails = u.fails, "Upstream failed: {}", e);
if u.fails > 3 { if u.fails > 3 {
u.healthy = false; u.healthy = false;
warn!("Upstream disabled due to failures: {:?}", u.config); warn!("Upstream marked unhealthy");
} }
} }
Err(e) Err(e)
@@ -129,18 +247,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(target, bind_ip)?; let socket = create_outgoing_socket_bound(target, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&target.into()) { match socket.connect(&target.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -149,8 +265,6 @@ impl UpstreamManager {
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks4 { address, interface, user_id } => { UpstreamType::Socks4 { address, interface, user_id } => {
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse() let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
@@ -159,18 +273,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) { match socket.connect(&proxy_addr.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?; let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -180,8 +292,6 @@ impl UpstreamManager {
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks5 { address, interface, username, password } => { UpstreamType::Socks5 { address, interface, username, password } => {
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse() let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
@@ -190,18 +300,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) { match socket.connect(&proxy_addr.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?; let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -213,13 +321,100 @@ impl UpstreamManager {
} }
} }
/// Background task to check health // ============= Startup Ping =============
pub async fn run_health_checks(&self) {
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2) /// Ping all Telegram DCs through all upstreams.
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
guard.iter().enumerate()
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
}
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut dc_results = Vec::new();
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout(
Duration::from_secs(5),
self.ping_single_dc(upstream_config, dc_addr)
).await;
let result = match ping_result {
Ok(Ok(rtt_ms)) => {
// Store per-DC latency
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some("timeout (5s)".to_string()),
},
};
dc_results.push(result);
}
all_results.push(StartupPingResult {
results: dc_results,
upstream_name,
});
}
all_results
}
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0)
}
// ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval.
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut dc_rotation = 0usize;
loop { loop {
tokio::time::sleep(Duration::from_secs(60)).await; tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % datacenters.len();
dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT);
let count = self.upstreams.read().await.len(); let count = self.upstreams.read().await.len();
for i in 0..count { for i in 0..count {
@@ -228,6 +423,7 @@ impl UpstreamManager {
guard[i].config.clone() guard[i].config.clone()
}; };
let start = Instant::now();
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(10), Duration::from_secs(10),
self.connect_via_upstream(&config, check_target) self.connect_via_upstream(&config, check_target)
@@ -238,18 +434,36 @@ impl UpstreamManager {
match result { match result {
Ok(Ok(_stream)) => { Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy { if !u.healthy {
debug!("Upstream recovered: {:?}", u.config); info!(
rtt = format!("{:.0}ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered"
);
} }
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
} }
Ok(Err(e)) => { Ok(Err(e)) => {
debug!("Health check failed for {:?}: {}", u.config, e); u.fails += 1;
// Don't mark unhealthy immediately in background check debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed: {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
} }
Err(_) => { Err(_) => {
debug!("Health check timeout for {:?}", u.config); u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
} }
} }
u.last_check = std::time::Instant::now(); u.last_check = std::time::Instant::now();