18 Commits

Author SHA1 Message Date
Alexey
7f8cde8317 NAT + STUN Probes... 2026-02-14 12:44:20 +03:00
Alexey
e32d8e6c7d ME Diagnostics 2026-02-14 04:19:44 +03:00
Alexey
d405756b94 HOL Minimized + Random conn_id + Target DC Magics
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-14 01:52:49 +03:00
Alexey
a8c3128c50 Middle Proxy Magics
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-14 01:51:10 +03:00
Alexey
70859aa5cf Middle Proxy is so real 2026-02-14 01:36:14 +03:00
Alexey
9b850b0bfb IP Version Superfallback 2026-02-14 00:30:09 +03:00
Alexey
de28655dd2 Middle Proxy Fixes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 16:09:33 +03:00
Alexey
e62b41ae64 RPC Flags Fixes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 14:28:47 +03:00
Alexey
f1c1f42de8 Key derivation + me_health_monitor + QuickACK
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 12:51:49 +03:00
Alexey
a494dfa9eb Middle Proxy Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-13 03:51:36 +03:00
Alexey
e6bf7ac40e Merge pull request #42 from telemt/codeql-tuning
Codeql tuning
2026-02-13 03:02:08 +03:00
Alexey
889a5fa19b Add mask_unix_sock for [censorship] masking: merge pull request #33 from Katze-942/main
Add mask_unix_sock for [censorship] masking
2026-02-12 21:30:51 +03:00
Жора Змейкин
d8ff958481 Add mask_unix_sock for censorship masking via Unix socket 2026-02-12 21:11:20 +03:00
Alexey
28ee74787b Merge pull request #36 from telemt/1.2.0.3
New Relay on Tokio Copy Bidirectional
2026-02-12 20:34:35 +03:00
Alexey
91eea914b3 Update codeql.yml 2026-02-12 19:00:12 +03:00
Alexey
3ba97a08fa Update codeql.yml 2026-02-12 18:58:42 +03:00
Alexey
6e445be108 CodeQL Tuning 2026-02-12 18:58:03 +03:00
Alexey
3c6752644a Create codeql.yml 2026-02-12 18:56:08 +03:00
29 changed files with 4049 additions and 615 deletions

45
.github/workflows/codeql.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: "CodeQL Advanced"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
- cron: '0 0 * * 0'
jobs:
analyze:
name: Analyze (${{ matrix.language }})
runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
permissions:
security-events: write
packages: read
actions: read
contents: read
strategy:
fail-fast: false
matrix:
include:
- language: actions
build-mode: none
- language: rust
build-mode: none
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
config-file: .github/codeql/codeql-config.yml
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
with:
category: "/language:${{ matrix.language }}"

View File

@@ -0,0 +1,20 @@
import rust
predicate isTestOnly(Item i) {
exists(ConditionalCompilation cc |
cc.getItem() = i and
cc.getCfg().toString() = "test"
)
}
predicate hasTestAttribute(Item i) {
exists(Attribute a |
a.getItem() = i and
a.getName() = "test"
)
}
predicate isProductionCode(Item i) {
not isTestOnly(i) and
not hasTestAttribute(i)
}

4
.github/workflows/queries/qlpack.yml vendored Normal file
View File

@@ -0,0 +1,4 @@
name: rust-production-only
version: 0.0.1
dependencies:
codeql/rust-all: "*"

View File

@@ -188,6 +188,7 @@ tls_domain = "petrovich.ru"
mask = true mask = true
mask_port = 443 mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set # mask_host = "petrovich.ru" # Defaults to tls_domain if not set
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
fake_cert_len = 2048 fake_cert_len = 2048
# === Access Control & Users === # === Access Control & Users ===

View File

@@ -44,10 +44,11 @@ client_ack = 300
# === Anti-Censorship & Masking === # === Anti-Censorship & Masking ===
[censorship] [censorship]
tls_domain = "google.ru" tls_domain = "petrovich.ru"
mask = true mask = true
mask_port = 443 mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set # mask_host = "petrovich.ru" # Defaults to tls_domain if not set
# mask_unix_sock = "/var/run/nginx.sock" # Unix socket (mutually exclusive with mask_host)
fake_cert_len = 2048 fake_cert_len = 2048
# === Access Control & Users === # === Access Control & Users ===
@@ -74,6 +75,6 @@ weight = 10
# [[upstreams]] # [[upstreams]]
# type = "socks5" # type = "socks5"
# address = "127.0.0.1:9050" # address = "127.0.0.1:1080"
# enabled = false # enabled = false
# weight = 1 # weight = 1

View File

@@ -1,32 +1,55 @@
//! Configuration //! Configuration
use crate::error::{ProxyError, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::path::Path; use std::path::Path;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result};
// ============= Helper Defaults ============= // ============= Helper Defaults =============
fn default_true() -> bool { true } fn default_true() -> bool {
fn default_port() -> u16 { 443 } true
fn default_tls_domain() -> String { "www.google.com".to_string() } }
fn default_mask_port() -> u16 { 443 } fn default_port() -> u16 {
fn default_replay_check_len() -> usize { 65536 } 443
fn default_replay_window_secs() -> u64 { 1800 } }
fn default_handshake_timeout() -> u64 { 15 } fn default_tls_domain() -> String {
fn default_connect_timeout() -> u64 { 10 } "www.google.com".to_string()
fn default_keepalive() -> u64 { 60 } }
fn default_ack_timeout() -> u64 { 300 } fn default_mask_port() -> u16 {
fn default_listen_addr() -> String { "0.0.0.0".to_string() } 443
fn default_fake_cert_len() -> usize { 2048 } }
fn default_weight() -> u16 { 1 } fn default_replay_check_len() -> usize {
65536
}
fn default_replay_window_secs() -> u64 {
1800
}
fn default_handshake_timeout() -> u64 {
15
}
fn default_connect_timeout() -> u64 {
10
}
fn default_keepalive() -> u64 {
60
}
fn default_ack_timeout() -> u64 {
300
}
fn default_listen_addr() -> String {
"0.0.0.0".to_string()
}
fn default_fake_cert_len() -> usize {
2048
}
fn default_weight() -> u16 {
1
}
fn default_metrics_whitelist() -> Vec<IpAddr> { fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![ vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()]
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
} }
// ============= Log Level ============= // ============= Log Level =============
@@ -96,7 +119,11 @@ pub struct ProxyModes {
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { classic: true, secure: true, tls: true } Self {
classic: true,
secure: true,
tls: true,
}
} }
} }
@@ -117,6 +144,24 @@ pub struct GeneralConfig {
#[serde(default)] #[serde(default)]
pub ad_tag: Option<String>, pub ad_tag: Option<String>,
/// Path to proxy-secret binary file (auto-downloaded if absent).
/// Infrastructure secret from https://core.telegram.org/getProxySecret
#[serde(default)]
pub proxy_secret_path: Option<String>,
/// Public IP override for middle-proxy NAT environments.
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
#[serde(default)]
pub middle_proxy_nat_ip: Option<IpAddr>,
/// Enable STUN-based NAT probing to discover public IP:port for ME KDF.
#[serde(default)]
pub middle_proxy_nat_probe: bool,
/// Optional STUN server address (host:port) for NAT probing.
#[serde(default)]
pub middle_proxy_nat_stun: Option<String>,
#[serde(default)] #[serde(default)]
pub log_level: LogLevel, pub log_level: LogLevel,
} }
@@ -129,6 +174,10 @@ impl Default for GeneralConfig {
fast_mode: true, fast_mode: true,
use_middle_proxy: false, use_middle_proxy: false,
ad_tag: None, ad_tag: None,
proxy_secret_path: None,
middle_proxy_nat_ip: None,
middle_proxy_nat_probe: false,
middle_proxy_nat_stun: None,
log_level: LogLevel::Normal, log_level: LogLevel::Normal,
} }
} }
@@ -212,6 +261,9 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_port")] #[serde(default = "default_mask_port")]
pub mask_port: u16, pub mask_port: u16,
#[serde(default)]
pub mask_unix_sock: Option<String>,
#[serde(default = "default_fake_cert_len")] #[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize, pub fake_cert_len: usize,
} }
@@ -223,6 +275,7 @@ impl Default for AntiCensorshipConfig {
mask: true, mask: true,
mask_host: None, mask_host: None,
mask_port: default_mask_port(), mask_port: default_mask_port(),
mask_unix_sock: None,
fake_cert_len: default_fake_cert_len(), fake_cert_len: default_fake_cert_len(),
} }
} }
@@ -255,7 +308,10 @@ pub struct AccessConfig {
impl Default for AccessConfig { impl Default for AccessConfig {
fn default() -> Self { fn default() -> Self {
let mut users = HashMap::new(); let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); users.insert(
"default".to_string(),
"00000000000000000000000000000000".to_string(),
);
Self { Self {
users, users,
user_max_tcp_conns: HashMap::new(), user_max_tcp_conns: HashMap::new(),
@@ -355,11 +411,11 @@ pub struct ProxyConfig {
impl ProxyConfig { impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path) let content =
.map_err(|e| ProxyError::Config(e.to_string()))?; std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
let mut config: ProxyConfig = toml::from_str(&content) let mut config: ProxyConfig =
.map_err(|e| ProxyError::Config(e.to_string()))?; toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.access.users { for (user, secret) in &config.access.users {
@@ -376,8 +432,34 @@ impl ProxyConfig {
return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
} }
// Default mask_host to tls_domain if not set // Validate mask_unix_sock
if config.censorship.mask_host.is_none() { if let Some(ref sock_path) = config.censorship.mask_unix_sock {
if sock_path.is_empty() {
return Err(ProxyError::Config(
"mask_unix_sock cannot be empty".to_string(),
));
}
#[cfg(unix)]
if sock_path.len() > 107 {
return Err(ProxyError::Config(format!(
"mask_unix_sock path too long: {} bytes (max 107)",
sock_path.len()
)));
}
#[cfg(not(unix))]
return Err(ProxyError::Config(
"mask_unix_sock is only supported on Unix platforms".to_string(),
));
if config.censorship.mask_host.is_some() {
return Err(ProxyError::Config(
"mask_unix_sock and mask_host are mutually exclusive".to_string(),
));
}
}
// Default mask_host to tls_domain if not set and no unix socket configured
if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
@@ -394,7 +476,7 @@ impl ProxyConfig {
}); });
} }
if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() { if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.server.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv6, ip: ipv6,
announce_ip: None, announce_ip: None,
@@ -405,7 +487,7 @@ impl ProxyConfig {
// Migration: Populate upstreams if empty (Default Direct) // Migration: Populate upstreams if empty (Default Direct)
if config.upstreams.is_empty() { if config.upstreams.is_empty() {
config.upstreams.push(UpstreamConfig { config.upstreams.push(UpstreamConfig {
upstream_type: UpstreamType::Direct { interface: None }, upstream_type: UpstreamType::Direct { interface: None },
weight: 1, weight: 1,
enabled: true, enabled: true,
@@ -425,9 +507,10 @@ impl ProxyConfig {
} }
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config( return Err(ProxyError::Config(format!(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) "Invalid tls_domain: '{}'. Must be a valid domain name",
)); self.censorship.tls_domain
)));
} }
Ok(()) Ok(())

View File

@@ -55,12 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// Middle Proxy key derivation /// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
/// ///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol. /// Returned buffer layout (IPv4):
/// These algorithms are NOT replaceable here changing them would break /// nonce_srv | nonce_clt | clt_ts | srv_ip | clt_port | purpose | clt_ip | srv_port | secret | nonce_srv | [clt_v6 | srv_v6] | nonce_clt
/// interoperability with Telegram's middle proxy infrastructure. pub fn build_middleproxy_prekey(
pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16], nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16], nonce_clt: &[u8; 16],
clt_ts: &[u8; 4], clt_ts: &[u8; 4],
@@ -72,7 +71,7 @@ pub fn derive_middleproxy_keys(
secret: &[u8], secret: &[u8],
clt_ipv6: Option<&[u8; 16]>, clt_ipv6: Option<&[u8; 16]>,
srv_ipv6: Option<&[u8; 16]>, srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) { ) -> Vec<u8> {
const EMPTY_IP: [u8; 4] = [0, 0, 0, 0]; const EMPTY_IP: [u8; 4] = [0, 0, 0, 0];
let srv_ip = srv_ip.unwrap_or(&EMPTY_IP); let srv_ip = srv_ip.unwrap_or(&EMPTY_IP);
@@ -96,6 +95,40 @@ pub fn derive_middleproxy_keys(
} }
s.extend_from_slice(nonce_clt); s.extend_from_slice(nonce_clt);
s
}
/// Middle Proxy key derivation
///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
/// These algorithms are NOT replaceable here — changing them would break
/// interoperability with Telegram's middle proxy infrastructure.
pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16],
clt_ts: &[u8; 4],
srv_ip: Option<&[u8]>,
clt_port: &[u8; 2],
purpose: &[u8],
clt_ip: Option<&[u8]>,
srv_port: &[u8; 2],
secret: &[u8],
clt_ipv6: Option<&[u8; 16]>,
srv_ipv6: Option<&[u8; 16]>,
) -> ([u8; 32], [u8; 16]) {
let s = build_middleproxy_prekey(
nonce_srv,
nonce_clt,
clt_ts,
srv_ip,
clt_port,
purpose,
clt_ip,
srv_port,
secret,
clt_ipv6,
srv_ipv6,
);
let md5_1 = md5(&s[1..]); let md5_1 = md5(&s[1..]);
let sha1_sum = sha1(&s); let sha1_sum = sha1(&s);
@@ -107,3 +140,39 @@ pub fn derive_middleproxy_keys(
(key, md5_2) (key, md5_2)
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn middleproxy_prekey_sha_is_stable() {
let nonce_srv = [0x11u8; 16];
let nonce_clt = [0x22u8; 16];
let clt_ts = 0x44332211u32.to_le_bytes();
let srv_ip = Some([149u8, 154, 175, 50].as_ref());
let clt_ip = Some([10u8, 0, 0, 1].as_ref());
let clt_port = 0x1f90u16.to_le_bytes(); // 8080
let srv_port = 0x22b8u16.to_le_bytes(); // 8888
let secret = vec![0x55u8; 128];
let prekey = build_middleproxy_prekey(
&nonce_srv,
&nonce_clt,
&clt_ts,
srv_ip,
&clt_port,
b"CLIENT",
clt_ip,
&srv_port,
&secret,
None,
None,
);
let digest = sha256(&prekey);
assert_eq!(
hex::encode(digest),
"a4595b75f1f610f2575ace802ddc65c91b5acef3b0e0d18189e0c7c9f787d15c"
);
}
}

View File

@@ -5,5 +5,5 @@ pub mod hash;
pub mod random; pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey};
pub use random::SecureRandom; pub use random::SecureRandom;

View File

@@ -1,4 +1,4 @@
//! Telemt - MTProxy on Rust //! telemt — Telegram MTProto Proxy
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
@@ -6,8 +6,8 @@ use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::signal; use tokio::signal;
use tokio::sync::Semaphore; use tokio::sync::Semaphore;
use tracing::{info, error, warn, debug}; use tracing::{debug, error, info, warn};
use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*}; use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};
mod cli; mod cli;
mod config; mod config;
@@ -20,13 +20,14 @@ mod stream;
mod transport; mod transport;
mod util; mod util;
use crate::config::{ProxyConfig, LogLevel}; use crate::config::{LogLevel, ProxyConfig};
use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::proxy::ClientHandler;
use crate::util::ip::detect_ip; use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool; use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool;
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
use crate::util::ip::detect_ip;
fn parse_cli() -> (String, bool, Option<String>) { fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string(); let mut config_path = "config.toml".to_string();
@@ -47,10 +48,14 @@ fn parse_cli() -> (String, bool, Option<String>) {
let mut i = 0; let mut i = 0;
while i < args.len() { while i < args.len() {
match args[i].as_str() { match args[i].as_str() {
"--silent" | "-s" => { silent = true; } "--silent" | "-s" => {
silent = true;
}
"--log-level" => { "--log-level" => {
i += 1; i += 1;
if i < args.len() { log_level = Some(args[i].clone()); } if i < args.len() {
log_level = Some(args[i].clone());
}
} }
s if s.starts_with("--log-level=") => { s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string()); log_level = Some(s.trim_start_matches("--log-level=").to_string());
@@ -64,17 +69,27 @@ fn parse_cli() -> (String, bool, Option<String>) {
eprintln!(" --help, -h Show this help"); eprintln!(" --help, -h Show this help");
eprintln!(); eprintln!();
eprintln!("Setup (fire-and-forget):"); eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start"); eprintln!(
" --init Generate config, install systemd service, start"
);
eprintln!(" --port <PORT> Listen port (default: 443)"); eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)"); eprintln!(
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)"); " --domain <DOMAIN> TLS domain for masking (default: www.google.com)"
);
eprintln!(
" --secret <HEX> 32-char hex secret (auto-generated if omitted)"
);
eprintln!(" --user <NAME> Username (default: user)"); eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)"); eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install"); eprintln!(" --no-start Don't start the service after install");
std::process::exit(0); std::process::exit(0);
} }
s if !s.starts_with('-') => { config_path = s.to_string(); } s if !s.starts_with('-') => {
other => { eprintln!("Unknown option: {}", other); } config_path = s.to_string();
}
other => {
eprintln!("Unknown option: {}", other);
}
} }
i += 1; i += 1;
} }
@@ -83,7 +98,7 @@ fn parse_cli() -> (String, bool, Option<String>) {
} }
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (config_path, cli_silent, cli_log_level) = parse_cli(); let (config_path, cli_silent, cli_log_level) = parse_cli();
let config = match ProxyConfig::load(&config_path) { let config = match ProxyConfig::load(&config_path) {
@@ -115,8 +130,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config.general.log_level.clone() config.general.log_level.clone()
}; };
// Start with INFO so startup messages are always visible,
// then switch to user-configured level after startup
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info")); let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
tracing_subscriber::registry() tracing_subscriber::registry()
.with(filter_layer) .with(filter_layer)
@@ -125,21 +138,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level); info!("Log level: {}", effective_log_level);
info!("Modes: classic={} secure={} tls={}", info!(
config.general.modes.classic, "Modes: classic={} secure={} tls={}",
config.general.modes.secure, config.general.modes.classic, config.general.modes.secure, config.general.modes.tls
config.general.modes.tls); );
info!("TLS domain: {}", config.censorship.tls_domain); info!("TLS domain: {}", config.censorship.tls_domain);
info!("Mask: {} -> {}:{}", if let Some(ref sock) = config.censorship.mask_unix_sock {
config.censorship.mask, info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), if !std::path::Path::new(sock).exists() {
config.censorship.mask_port); warn!(
"Unix socket '{}' does not exist yet. Masking will fail until it appears.",
sock
);
}
} else {
info!(
"Mask: {} -> {}:{}",
config.censorship.mask,
config
.censorship
.mask_host
.as_deref()
.unwrap_or(&config.censorship.tls_domain),
config.censorship.mask_port
);
}
if config.censorship.tls_domain == "www.google.com" { if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain. Consider setting a custom domain."); warn!("Using default tls_domain. Consider setting a custom domain.");
} }
let prefer_ipv6 = config.general.prefer_ipv6; let prefer_ipv6 = config.general.prefer_ipv6;
let use_middle_proxy = config.general.use_middle_proxy;
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new()); let rng = Arc::new(SecureRandom::new());
@@ -152,41 +182,193 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Connection concurrency limit — prevents OOM under SYN flood / connection storm. // Connection concurrency limit
// 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead). let _max_connections = Arc::new(Semaphore::new(10_000));
// 10000 connections ≈ 640MB peak memory.
let max_connections = Arc::new(Semaphore::new(10_000));
// Startup DC ping // =====================================================================
info!("=== Telegram DC Connectivity ==="); // Middle Proxy initialization (if enabled)
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await; // =====================================================================
for upstream_result in &ping_results { let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
info!(" via {}", upstream_result.upstream_name); info!("=== Middle Proxy Mode ===");
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) { // ad_tag (proxy_tag) for advertising
(Some(rtt), _) => { let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt); hex::decode(tag).unwrap_or_else(|_| {
} warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
(None, Some(err)) => { Vec::new()
info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err); })
} });
_ => {
info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr); // =============================================================
// CRITICAL: Download Telegram proxy-secret (NOT user secret!)
//
// C MTProxy uses TWO separate secrets:
// -S flag = 16-byte user secret for client obfuscation
// --aes-pwd = 32-512 byte binary file for ME RPC auth
//
// proxy-secret is from: https://core.telegram.org/getProxySecret
// =============================================================
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
Ok(proxy_secret) => {
info!(
secret_len = proxy_secret.len(),
key_sig = format_args!(
"0x{:08x}",
if proxy_secret.len() >= 4 {
u32::from_le_bytes([
proxy_secret[0],
proxy_secret[1],
proxy_secret[2],
proxy_secret[3],
])
} else {
0
}
),
"Proxy-secret loaded"
);
let pool = MePool::new(
proxy_tag,
proxy_secret,
config.general.middle_proxy_nat_ip,
config.general.middle_proxy_nat_probe,
config.general.middle_proxy_nat_stun.clone(),
);
match pool.init(2, &rng).await {
Ok(()) => {
info!("Middle-End pool initialized successfully");
// Phase 4: Start health monitor
let pool_clone = pool.clone();
let rng_clone = rng.clone();
tokio::spawn(async move {
crate::transport::middle_proxy::me_health_monitor(
pool_clone, rng_clone, 2,
)
.await;
});
Some(pool)
}
Err(e) => {
error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
None
}
} }
} }
Err(e) => {
error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
None
}
}
} else {
None
};
if me_pool.is_some() {
info!("Transport: Middle Proxy (supports all DCs including CDN)");
} else {
info!("Transport: Direct TCP (standard DCs only)");
}
// Startup DC ping (only meaningful in direct mode)
if me_pool.is_none() {
info!("================= Telegram DC Connectivity =================");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
// Show which IP version is in use and which is fallback
if upstream_result.both_available {
if prefer_ipv6 {
info!(" IPv6 in use and IPv4 is fallback");
} else {
info!(" IPv4 in use and IPv6 is fallback");
}
} else {
let v6_works = upstream_result
.v6_results
.iter()
.any(|r| r.rtt_ms.is_some());
let v4_works = upstream_result
.v4_results
.iter()
.any(|r| r.rtt_ms.is_some());
if v6_works && !v4_works {
info!(" IPv6 only (IPv4 unavailable)");
} else if v4_works && !v6_works {
info!(" IPv4 only (IPv6 unavailable)");
} else if !v6_works && !v4_works {
info!(" No connectivity!");
}
}
info!(" via {}", upstream_result.upstream_name);
info!("============================================================");
// Print IPv6 results first
for dc in &upstream_result.v6_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
// Align: IPv6 addresses are longer, use fewer tabs
// [2001:b28:f23d:f001::a]:443 = ~28 chars
info!(" DC{} [IPv6] {}:\t\t{:.0} ms", dc.dc_idx, addr_str, rtt);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(" DC{} [IPv6] {}:\t\tFAIL ({})", dc.dc_idx, addr_str, err);
}
}
}
info!("============================================================");
// Print IPv4 results
for dc in &upstream_result.v4_results {
let addr_str = format!("{}:{}", dc.dc_addr.ip(), dc.dc_addr.port());
match &dc.rtt_ms {
Some(rtt) => {
// Align: IPv4 addresses are shorter, use more tabs
// 149.154.175.50:443 = ~18 chars
info!(
" DC{} [IPv4] {}:\t\t\t\t{:.0} ms",
dc.dc_idx, addr_str, rtt
);
}
None => {
let err = dc.error.as_deref().unwrap_or("fail");
info!(
" DC{} [IPv4] {}:\t\t\t\tFAIL ({})",
dc.dc_idx, addr_str, err
);
}
}
}
info!("============================================================");
} }
} }
info!("================================");
// Background tasks // Background tasks
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; }); tokio::spawn(async move {
um_clone.run_health_checks(prefer_ipv6).await;
});
let rc_clone = replay_checker.clone(); let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; }); tokio::spawn(async move {
rc_clone.run_periodic_cleanup().await;
});
let detected_ip = detect_ip().await; let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); debug!(
"Detected IPs: v4={:?} v6={:?}",
detected_ip.ipv4, detected_ip.ipv6
);
let mut listeners = Vec::new(); let mut listeners = Vec::new();
@@ -220,17 +402,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Some(secret) = config.access.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name); info!("User: {}", user_name);
if config.general.modes.classic { if config.general.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}", info!(
public_ip, config.server.port, secret); " Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.server.port, secret
);
} }
if config.general.modes.secure { if config.general.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", info!(
public_ip, config.server.port, secret); " DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.server.port, secret
);
} }
if config.general.modes.tls { if config.general.modes.tls {
let domain_hex = hex::encode(&config.censorship.tls_domain); let domain_hex = hex::encode(&config.censorship.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", info!(
public_ip, config.server.port, secret, domain_hex); " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.server.port, secret, domain_hex
);
} }
} else { } else {
warn!("User '{}' in show_link not found", user_name); warn!("User '{}' in show_link not found", user_name);
@@ -240,7 +428,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
listeners.push(listener); listeners.push(listener);
}, }
Err(e) => { Err(e) => {
error!("Failed to bind to {}: {}", addr, e); error!("Failed to bind to {}: {}", addr, e);
} }
@@ -258,7 +446,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} else { } else {
EnvFilter::new(effective_log_level.to_filter_str()) EnvFilter::new(effective_log_level.to_filter_str())
}; };
filter_handle.reload(runtime_filter).expect("Failed to switch log filter"); filter_handle
.reload(runtime_filter)
.expect("Failed to switch log filter");
for listener in listeners { for listener in listeners {
let config = config.clone(); let config = config.clone();
@@ -267,6 +457,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -278,12 +469,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
stream, peer_addr, config, stats, stream,
upstream_manager, replay_checker, buffer_pool, rng peer_addr,
).run().await { config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
me_pool,
)
.run()
.await
{
debug!(peer = %peer_addr, error = %e, "Connection error"); debug!(peer = %peer_addr, error = %e, "Connection error");
} }
}); });

View File

@@ -202,6 +202,17 @@ pub static RESERVED_NONCE_CONTINUES: &[[u8; 4]] = &[
// ============= RPC Constants (for Middle Proxy) ============= // ============= RPC Constants (for Middle Proxy) =============
/// RPC Proxy Request /// RPC Proxy Request
/// RPC Flags (from Erlang mtp_rpc.erl)
pub const RPC_FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const RPC_FLAG_HAS_AD_TAG: u32 = 0x8;
pub const RPC_FLAG_MAGIC: u32 = 0x1000;
pub const RPC_FLAG_EXTMODE2: u32 = 0x20000;
pub const RPC_FLAG_PAD: u32 = 0x8000000;
pub const RPC_FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const RPC_FLAG_ABRIDGED: u32 = 0x40000000;
pub const RPC_FLAG_QUICKACK: u32 = 0x80000000;
pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36]; pub const RPC_PROXY_REQ: [u8; 4] = [0xee, 0xf1, 0xce, 0x36];
/// RPC Proxy Answer /// RPC Proxy Answer
pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44]; pub const RPC_PROXY_ANS: [u8; 4] = [0x0d, 0xda, 0x03, 0x44];
@@ -228,7 +239,56 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
#[cfg(test)]
// ============= Middle-End Proxy Servers =============
pub const ME_PROXY_PORT: u16 = 8888;
pub static TG_MIDDLE_PROXIES_FLAT_V4: LazyLock<Vec<(IpAddr, u16)>> = LazyLock::new(|| {
vec![
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 161, 144)), 8888),
(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 100)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 4, 136)), 8888),
(IpAddr::V4(Ipv4Addr::new(91, 108, 56, 183)), 8888),
]
});
// ============= RPC Constants (u32 native endian) =============
// From mtproto-common.h + net-tcp-rpc-common.h + mtproto-proxy.c
pub const RPC_NONCE_U32: u32 = 0x7acb87aa;
pub const RPC_HANDSHAKE_U32: u32 = 0x7682eef5;
pub const RPC_HANDSHAKE_ERROR_U32: u32 = 0x6a27beda;
pub const TL_PROXY_TAG_U32: u32 = 0xdb1e26ae; // mtproto-proxy.c:121
// mtproto-common.h
pub const RPC_PROXY_REQ_U32: u32 = 0x36cef1ee;
pub const RPC_PROXY_ANS_U32: u32 = 0x4403da0d;
pub const RPC_CLOSE_CONN_U32: u32 = 0x1fcf425d;
pub const RPC_CLOSE_EXT_U32: u32 = 0x5eb634a2;
pub const RPC_SIMPLE_ACK_U32: u32 = 0x3bac409b;
pub const RPC_PING_U32: u32 = 0x5730a2df;
pub const RPC_PONG_U32: u32 = 0x8430eaa7;
pub const RPC_CRYPTO_NONE_U32: u32 = 0;
pub const RPC_CRYPTO_AES_U32: u32 = 1;
pub mod proxy_flags {
pub const FLAG_HAS_AD_TAG: u32 = 1;
pub const FLAG_NOT_ENCRYPTED: u32 = 0x2;
pub const FLAG_HAS_AD_TAG2: u32 = 0x8;
pub const FLAG_MAGIC: u32 = 0x1000;
pub const FLAG_EXTMODE2: u32 = 0x20000;
pub const FLAG_PAD: u32 = 0x8000000;
pub const FLAG_INTERMEDIATE: u32 = 0x20000000;
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
#[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@@ -3,26 +3,25 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::{debug, info, warn, error, trace}; use tracing::{debug, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result, HandshakeResult}; use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{ReplayChecker, Stats};
use crate::transport::{configure_client_socket, UpstreamManager}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; use crate::transport::middle_proxy::MePool;
use crate::crypto::{AesCtr, SecureRandom}; use crate::transport::{UpstreamManager, configure_client_socket};
use crate::proxy::handshake::{ use crate::proxy::direct_relay::handle_via_direct;
handle_tls_handshake, handle_mtproto_handshake, use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
};
use crate::proxy::relay::relay_bidirectional;
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy;
pub struct ClientHandler; pub struct ClientHandler;
@@ -35,6 +34,7 @@ pub struct RunningClientHandler {
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
} }
impl ClientHandler { impl ClientHandler {
@@ -47,10 +47,18 @@ impl ClientHandler {
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
) -> RunningClientHandler { ) -> RunningClientHandler {
RunningClientHandler { RunningClientHandler {
stream, peer, config, stats, replay_checker, stream,
upstream_manager, buffer_pool, rng, peer,
config,
stats,
replay_checker,
upstream_manager,
buffer_pool,
rng,
me_pool,
} }
} }
} }
@@ -132,12 +140,20 @@ impl RunningClientHandler {
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, read_half, write_half, peer, &handshake,
&config, &replay_checker, &self.rng, read_half,
).await { write_half,
peer,
&config,
&replay_checker,
&self.rng,
)
.await
{
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
@@ -149,15 +165,26 @@ impl RunningClientHandler {
debug!(peer = %peer, "Reading MTProto handshake through TLS"); debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
.try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, tls_reader, tls_writer, peer, &mtproto_handshake,
&config, &replay_checker, true, tls_reader,
).await { tls_writer,
peer,
&config,
&replay_checker,
true,
)
.await
{
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader: _, writer: _ } => { HandshakeResult::BadClient {
reader: _,
writer: _,
} => {
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(());
@@ -166,10 +193,18 @@ impl RunningClientHandler {
}; };
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_writer, success, crypto_reader,
self.upstream_manager, self.stats, self.config, crypto_writer,
buffer_pool, self.rng, success,
).await self.upstream_manager,
self.stats,
self.config,
buffer_pool,
self.rng,
self.me_pool,
local_addr,
)
.await
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
@@ -192,12 +227,20 @@ impl RunningClientHandler {
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, read_half, write_half, peer, &handshake,
&config, &replay_checker, false, read_half,
).await { write_half,
peer,
&config,
&replay_checker,
false,
)
.await
{
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
@@ -208,12 +251,24 @@ impl RunningClientHandler {
}; };
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_writer, success, crypto_reader,
self.upstream_manager, self.stats, self.config, crypto_writer,
buffer_pool, self.rng, success,
).await self.upstream_manager,
self.stats,
self.config,
buffer_pool,
self.rng,
self.me_pool,
local_addr,
)
.await
} }
/// Main dispatch after successful handshake.
/// Two modes:
/// - Direct: TCP relay to TG DC (existing behavior)
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
async fn handle_authenticated_static<R, W>( async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@@ -223,6 +278,8 @@ impl RunningClientHandler {
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>,
local_addr: SocketAddr,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -235,168 +292,63 @@ impl RunningClientHandler {
return Err(e); return Err(e);
} }
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; // Decide: middle proxy or direct
if config.general.use_middle_proxy {
info!( if let Some(ref pool) = me_pool {
user = %user, return handle_via_middle_proxy(
peer = %success.peer, client_reader,
dc = success.dc_idx, client_writer,
dc_addr = %dc_addr, success,
proto = ?success.proto_tag, pool.clone(),
"Connecting to Telegram" stats,
); config,
buffer_pool,
// Pass dc_idx for latency-based upstream selection local_addr,
let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?; )
.await;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); }
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
tg_stream, &success, &config, rng.as_ref(),
).await?;
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
let relay_result = relay_bidirectional(
client_reader, client_writer,
tg_reader, tg_writer,
user, Arc::clone(&stats), buffer_pool,
).await;
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, "Relay completed"),
Err(e) => debug!(user = %user, error = %e, "Relay ended with error"),
} }
relay_result // Direct mode (original behavior)
handle_via_direct(
client_reader,
client_writer,
success,
upstream_manager,
stats,
config,
buffer_pool,
rng,
)
.await
} }
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
if let Some(expiration) = config.access.user_expirations.get(user) { if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired {
user: user.to_string(),
});
} }
} }
if let Some(limit) = config.access.user_max_tcp_conns.get(user) { if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
if stats.get_user_curr_connects(user) >= *limit as u64 { if stats.get_user_curr_connects(user) >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
} }
} }
if let Some(quota) = config.access.user_data_quota.get(user) { if let Some(quota) = config.access.user_data_quota.get(user) {
if stats.get_user_total_octets(user) >= *quota { if stats.get_user_total_octets(user) >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
} }
} }
Ok(()) Ok(())
} }
/// Resolve DC index to a target address.
///
/// Matches the C implementation's behavior exactly:
///
/// 1. Look up DC in known clusters (standard DCs ±1..±5)
/// 2. If not found and `force=1` → fall back to `default_cluster`
///
/// In the C code:
/// - `proxy-multi.conf` is downloaded from Telegram, contains only DC ±1..±5
/// - `default 2;` directive sets the default cluster
/// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns default_cluster
/// for any unknown DC (like CDN DC 203)
///
/// So DC 203, DC 101, DC -300, etc. all route to the default DC (2).
/// There is NO modular arithmetic in the C implementation.
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
let num_dcs = datacenters.len(); // 5
// === Step 1: Check dc_overrides (like C's `proxy_for <dc> <ip>:<port>`) ===
let dc_key = dc_idx.to_string();
if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
return Ok(addr);
}
Err(_) => {
warn!(dc_idx = dc_idx, addr_str = %addr_str,
"Invalid DC override address in config, ignoring");
}
}
}
// === Step 2: Standard DCs ±1..±5 — direct lookup ===
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
// === Step 3: Unknown DC — fall back to default_cluster ===
// Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)`
// which returns `MC->default_cluster` when the DC is not found.
// Telegram's proxy-multi.conf uses `default 2;`
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1 // DC 2 (index 1) — matches Telegram's `default 2;`
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"Special DC ---> default_cluster"
);
Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT))
}
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
&success.dec_key,
success.dec_iv,
rng,
config.general.fast_mode,
);
let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
"Sending nonce to Telegram"
);
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
Ok((
CryptoReader::new(read_half, decryptor),
CryptoWriter::new(write_half, encryptor),
))
}
} }

163
src/proxy/direct_relay.rs Normal file
View File

@@ -0,0 +1,163 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::Result;
use crate::protocol::constants::*;
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = &success.user;
let dc_addr = get_dc_addr_static(success.dc_idx, &config)?;
info!(
user = %user,
peer = %success.peer,
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
mode = "direct",
"Connecting to Telegram DC"
);
let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx))
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
let (tg_reader, tg_writer) =
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&stats),
buffer_pool,
)
.await;
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
}
relay_result
}
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
let num_dcs = datacenters.len();
let dc_key = dc_idx.to_string();
if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
return Ok(addr);
}
Err(_) => {
warn!(dc_idx = dc_idx, addr_str = %addr_str,
"Invalid DC override address in config, ignoring");
}
}
}
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"Special DC ---> default_cluster"
);
Ok(SocketAddr::new(
datacenters[fallback_idx],
TG_DATACENTER_PORT,
))
}
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(
CryptoReader<tokio::net::tcp::OwnedReadHalf>,
CryptoWriter<tokio::net::tcp::OwnedWriteHalf>,
)> {
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
success.dc_idx,
&success.dec_key,
success.dec_iv,
rng,
config.general.fast_mode,
);
let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
"Sending nonce to Telegram"
);
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
let (read_half, write_half) = stream.into_split();
Ok((
CryptoReader::new(read_half, tg_decryptor),
CryptoWriter::new(write_half, tg_encryptor),
))
}

View File

@@ -218,7 +218,6 @@ where
replay_checker.add_handshake(dec_prekey_iv); replay_checker.add_handshake(dec_prekey_iv);
let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
let success = HandshakeSuccess { let success = HandshakeSuccess {
@@ -256,6 +255,7 @@ where
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
dc_idx: i16,
client_dec_key: &[u8; 32], client_dec_key: &[u8; 32],
client_dec_iv: u128, client_dec_iv: u128,
rng: &SecureRandom, rng: &SecureRandom,
@@ -274,6 +274,8 @@ pub fn generate_tg_nonce(
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; } if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// CRITICAL: write dc_idx so upstream DC knows where to route
nonce[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
if fast_mode { if fast_mode {
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
@@ -294,19 +296,32 @@ pub fn generate_tg_nonce(
} }
} }
/// Encrypt nonce for sending to Telegram /// Encrypt nonce for sending to Telegram and return cipher objects with correct counter state
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, AesCtr, AesCtr) {
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv); let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let encrypted_full = encryptor.encrypt(nonce); let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
let mut result = nonce[..PROTO_TAG_POS].to_vec(); let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
result let decryptor = AesCtr::new(&dec_key, dec_iv);
(result, encryptor, decryptor)
}
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted
} }
#[cfg(test)] #[cfg(test)]
@@ -320,7 +335,7 @@ mod tests {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
@@ -335,7 +350,7 @@ mod tests {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let (nonce, _, _, _, _) = let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false); generate_tg_nonce(ProtoTag::Secure, 2, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce); let encrypted = encrypt_tg_nonce(&nonce);

View File

@@ -3,6 +3,8 @@
use std::time::Duration; use std::time::Duration;
use std::str; use std::str;
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::debug; use tracing::debug;
@@ -45,8 +47,8 @@ fn detect_client_type(data: &[u8]) -> &'static str {
/// Handle a bad client by forwarding to mask host /// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client<R, W>( pub async fn handle_bad_client<R, W>(
mut reader: R, reader: R,
mut writer: W, writer: W,
initial_data: &[u8], initial_data: &[u8],
config: &ProxyConfig, config: &ProxyConfig,
) )
@@ -62,6 +64,34 @@ where
let client_type = detect_client_type(initial_data); let client_type = detect_client_type(initial_data);
// Connect via Unix socket or TCP
#[cfg(unix)]
if let Some(ref sock_path) = config.censorship.mask_unix_sock {
debug!(
client_type = client_type,
sock = %sock_path,
data_len = initial_data.len(),
"Forwarding bad client to mask unix socket"
);
let connect_result = timeout(MASK_TIMEOUT, UnixStream::connect(sock_path)).await;
match connect_result {
Ok(Ok(stream)) => {
let (mask_read, mask_write) = stream.into_split();
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
}
Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data(reader).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data(reader).await;
}
}
return;
}
let mask_host = config.censorship.mask_host.as_deref() let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain); .unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port; let mask_port = config.censorship.mask_port;
@@ -76,27 +106,37 @@ where
// Connect to mask host // Connect to mask host
let mask_addr = format!("{}:{}", mask_host, mask_port); let mask_addr = format!("{}:{}", mask_host, mask_port);
let connect_result = timeout( let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
MASK_TIMEOUT, match connect_result {
TcpStream::connect(&mask_addr) Ok(Ok(stream)) => {
).await; let (mask_read, mask_write) = stream.into_split();
relay_to_mask(reader, writer, mask_read, mask_write, initial_data).await;
let mask_stream = match connect_result { }
Ok(Ok(s)) => s,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(reader).await; consume_client_data(reader).await;
return;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(reader).await; consume_client_data(reader).await;
return;
} }
}; }
}
let (mut mask_read, mut mask_write) = mask_stream.into_split();
/// Relay traffic between client and mask backend
async fn relay_to_mask<R, W, MR, MW>(
mut reader: R,
mut writer: W,
mut mask_read: MR,
mut mask_write: MW,
initial_data: &[u8],
)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
MR: AsyncRead + Unpin + Send + 'static,
MW: AsyncWrite + Unpin + Send + 'static,
{
// Send initial data to mask host // Send initial data to mask host
if mask_write.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;

254
src/proxy/middle_relay.rs Normal file
View File

@@ -0,0 +1,254 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info, trace};
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
mut crypto_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
_config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let peer = success.peer;
let proto_tag = success.proto_tag;
info!(
user = %user,
peer = %peer,
dc = success.dc_idx,
proto = ?proto_tag,
mode = "middle_proxy",
"Routing via Middle-End"
);
let (conn_id, mut me_rx) = me_pool.registry().register().await;
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag());
debug!(
user = %user,
conn_id,
proto_flags = format_args!("0x{:08x}", proto_flags),
"ME relay started"
);
let translated_local_addr = me_pool.translate_our_addr(local_addr);
let result: Result<()> = loop {
tokio::select! {
client_frame = read_client_payload(&mut crypto_reader, proto_tag) => {
match client_frame {
Ok(Some(payload)) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64);
me_pool.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
&payload,
proto_flags,
).await?;
}
Ok(None) => {
debug!(conn_id, "Client EOF");
let _ = me_pool.send_close(conn_id).await;
break Ok(());
}
Err(e) => break Err(e),
}
}
me_msg = me_rx.recv() => {
match me_msg {
Some(MeResponse::Data { flags, data }) => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats.add_user_octets_to(&user, data.len() as u64);
write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?;
}
Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack");
write_client_ack(&mut crypto_writer, proto_tag, confirm).await?;
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close");
break Ok(());
}
None => {
debug!(conn_id, "ME channel closed");
break Err(ProxyError::Proxy("ME connection lost".into()));
}
}
}
}
};
debug!(user = %user, conn_id, "ME relay cleanup");
me_pool.registry().unregister(conn_id).await;
stats.decrement_user_curr_connects(&user);
result
}
async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag,
) -> Result<Option<Vec<u8>>>
where
R: AsyncRead + Unpin + Send + 'static,
{
let len = match proto_tag {
ProtoTag::Abridged => {
let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize
}
};
if len > 16 * 1024 * 1024 {
return Err(ProxyError::Proxy(format!("Frame too large: {len}")));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
Ok(Some(payload))
}
async fn write_client_payload<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
flags: u32,
data: &[u8],
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
match proto_tag {
ProtoTag::Abridged => {
if data.len() % 4 != 0 {
return Err(ProxyError::Proxy(format!(
"Abridged payload must be 4-byte aligned, got {}",
data.len()
)));
}
let len_words = data.len() / 4;
if len_words < 0x7f {
let mut first = len_words as u8;
if quickack {
first |= 0x80;
}
client_writer
.write_all(&[first])
.await
.map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) {
let mut first = 0x7fu8;
if quickack {
first |= 0x80;
}
let lw = (len_words as u32).to_le_bytes();
client_writer
.write_all(&[first, lw[0], lw[1], lw[2]])
.await
.map_err(ProxyError::Io)?;
} else {
return Err(ProxyError::Proxy(format!(
"Abridged frame too large: {}",
data.len()
)));
}
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len = data.len() as u32;
if quickack {
len |= 0x8000_0000;
}
client_writer
.write_all(&len.to_le_bytes())
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
}
client_writer.flush().await.map_err(ProxyError::Io)
}
async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
confirm: u32,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let bytes = if proto_tag == ProtoTag::Abridged {
confirm.to_be_bytes()
} else {
confirm.to_le_bytes()
};
client_writer
.write_all(&bytes)
.await
.map_err(ProxyError::Io)?;
client_writer.flush().await.map_err(ProxyError::Io)
}

View File

@@ -1,11 +1,13 @@
//! Proxy Defs //! Proxy Defs
pub mod handshake;
pub mod client; pub mod client;
pub mod relay; pub mod direct_relay;
pub mod handshake;
pub mod masking; pub mod masking;
pub mod middle_relay;
pub mod relay;
pub use handshake::*;
pub use client::ClientHandler; pub use client::ClientHandler;
pub use relay::*; pub use handshake::*;
pub use masking::*; pub use masking::*;
pub use relay::*;

View File

@@ -0,0 +1,925 @@
//! Middle Proxy RPC Transport
//!
//! Implements Telegram Middle-End RPC protocol for routing to ALL DCs (including CDN).
//!
//! ## Phase 3 fixes:
//! - ROOT CAUSE: Use Telegram proxy-secret (binary file) not user secret
//! - Streaming handshake response (no fixed-size read deadlock)
//! - Health monitoring + reconnection
//! - Hex diagnostics for debugging
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::time::{timeout, Instant};
use tracing::{debug, info, trace, warn, error};
use crate::crypto::{crc32, derive_middleproxy_keys, AesCbc, SecureRandom};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
// ========== Proxy Secret Fetching ==========
/// Fetch the Telegram proxy-secret binary file.
///
/// This is NOT the user secret (-S flag, 16 bytes hex for clients).
/// This is the infrastructure secret (--aes-pwd in C MTProxy),
/// a binary file of 32-512 bytes used for ME RPC key derivation.
///
/// Strategy: try local cache, then download from Telegram.
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
let cache = cache_path.unwrap_or("proxy-secret");
// 1. Try local cache (< 24h old)
if let Ok(metadata) = tokio::fs::metadata(cache).await {
if let Ok(modified) = metadata.modified() {
let age = std::time::SystemTime::now()
.duration_since(modified)
.unwrap_or(Duration::from_secs(u64::MAX));
if age < Duration::from_secs(86400) {
if let Ok(data) = tokio::fs::read(cache).await {
if data.len() >= 32 {
info!(
path = cache,
len = data.len(),
age_hours = age.as_secs() / 3600,
"Loaded proxy-secret from cache"
);
return Ok(data);
}
warn!(path = cache, len = data.len(), "Cached proxy-secret too short");
}
}
}
}
// 2. Download from Telegram
info!("Downloading proxy-secret from core.telegram.org...");
let data = download_proxy_secret().await?;
// 3. Cache locally (best-effort)
if let Err(e) = tokio::fs::write(cache, &data).await {
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
} else {
debug!(path = cache, len = data.len(), "Cached proxy-secret");
}
Ok(data)
}
async fn download_proxy_secret() -> Result<Vec<u8>> {
let url = "https://core.telegram.org/getProxySecret";
let resp = reqwest::get(url)
.await
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {}", e)))?;
if !resp.status().is_success() {
return Err(ProxyError::Proxy(format!(
"proxy-secret download HTTP {}", resp.status()
)));
}
let data = resp.bytes().await
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {}", e)))?
.to_vec();
if data.len() < 32 {
return Err(ProxyError::Proxy(format!(
"proxy-secret too short: {} bytes (need >= 32)", data.len()
)));
}
info!(len = data.len(), "Downloaded proxy-secret OK");
Ok(data)
}
// ========== RPC Frame helpers ==========
/// Build an RPC frame: [len(4) | seq_no(4) | payload | crc32(4)]
fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut f = Vec::with_capacity(total_len as usize);
f.extend_from_slice(&total_len.to_le_bytes());
f.extend_from_slice(&seq_no.to_le_bytes());
f.extend_from_slice(payload);
let c = crc32(&f);
f.extend_from_slice(&c.to_le_bytes());
f
}
/// Read one plaintext RPC frame. Returns (seq_no, payload).
async fn read_rpc_frame_plaintext(
rd: &mut (impl AsyncReadExt + Unpin),
) -> Result<(i32, Vec<u8>)> {
let mut len_buf = [0u8; 4];
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
let total_len = u32::from_le_bytes(len_buf) as usize;
if total_len < 12 || total_len > (1 << 24) {
return Err(ProxyError::InvalidHandshake(
format!("Bad RPC frame length: {}", total_len),
));
}
let mut rest = vec![0u8; total_len - 4];
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
let mut full = Vec::with_capacity(total_len);
full.extend_from_slice(&len_buf);
full.extend_from_slice(&rest);
let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes([
full[crc_offset], full[crc_offset + 1],
full[crc_offset + 2], full[crc_offset + 3],
]);
let actual_crc = crc32(&full[..crc_offset]);
if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(
format!("CRC mismatch: 0x{:08x} vs 0x{:08x}", expected_crc, actual_crc),
));
}
let seq_no = i32::from_le_bytes([full[4], full[5], full[6], full[7]]);
let payload = full[8..crc_offset].to_vec();
Ok((seq_no, payload))
}
// ========== RPC Nonce (32 bytes payload) ==========
fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
p[16..32].copy_from_slice(nonce);
p
}
fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> {
if d.len() < 32 {
return Err(ProxyError::InvalidHandshake(
format!("Nonce payload too short: {} bytes", d.len()),
));
}
let t = u32::from_le_bytes([d[0], d[1], d[2], d[3]]);
if t != RPC_NONCE_U32 {
return Err(ProxyError::InvalidHandshake(
format!("Expected RPC_NONCE 0x{:08x}, got 0x{:08x}", RPC_NONCE_U32, t),
));
}
let schema = u32::from_le_bytes([d[8], d[9], d[10], d[11]]);
let ts = u32::from_le_bytes([d[12], d[13], d[14], d[15]]);
let mut nonce = [0u8; 16];
nonce.copy_from_slice(&d[16..32]);
Ok((schema, ts, nonce))
}
// ========== RPC Handshake (32 bytes payload) ==========
fn build_handshake_payload(our_ip: u32, our_port: u16, peer_ip: u32, peer_port: u16) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
// flags = 0 at offset 4..8
// sender_pid: {ip(4), port(2), pid(2), utime(4)} at offset 8..20
p[8..12].copy_from_slice(&our_ip.to_le_bytes());
p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xFFFF) as u16;
p[14..16].copy_from_slice(&pid.to_le_bytes());
let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes());
// peer_pid: {ip(4), port(2), pid(2), utime(4)} at offset 20..32
p[20..24].copy_from_slice(&peer_ip.to_le_bytes());
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
}
// ========== CBC helpers ==========
fn cbc_encrypt_padded(key: &[u8; 32], iv: &[u8; 16], plaintext: &[u8]) -> Result<(Vec<u8>, [u8; 16])> {
let pad = (16 - (plaintext.len() % 16)) % 16;
let mut buf = plaintext.to_vec();
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(*key, *iv);
cipher.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {}", e)))?;
let mut new_iv = [0u8; 16];
if buf.len() >= 16 {
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
}
Ok((buf, new_iv))
}
fn cbc_decrypt_inplace(key: &[u8; 32], iv: &[u8; 16], data: &mut [u8]) -> Result<[u8; 16]> {
let mut new_iv = [0u8; 16];
if data.len() >= 16 {
new_iv.copy_from_slice(&data[data.len() - 16..]);
}
AesCbc::new(*key, *iv)
.decrypt_in_place(data)
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {}", e)))?;
Ok(new_iv)
}
// ========== IPv4 helpers ==========
fn ipv4_to_mapped_v6(ip: Ipv4Addr) -> [u8; 16] {
let mut buf = [0u8; 16];
buf[10] = 0xFF;
buf[11] = 0xFF;
let o = ip.octets();
buf[12] = o[0]; buf[13] = o[1]; buf[14] = o[2]; buf[15] = o[3];
buf
}
fn addr_to_ip_u32(addr: &SocketAddr) -> u32 {
match addr.ip() {
IpAddr::V4(v4) => u32::from_be_bytes(v4.octets()),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
u32::from_be_bytes(v4.octets())
} else { 0 }
}
}
}
// ========== ME Response ==========
#[derive(Debug)]
pub enum MeResponse {
Data(Bytes),
Ack(u32),
Close,
}
// ========== Connection Registry ==========
pub struct ConnRegistry {
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
next_id: AtomicU64,
}
impl ConnRegistry {
pub fn new() -> Self {
Self {
map: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(1),
}
}
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(256);
self.map.write().await.insert(id, tx);
(id, rx)
}
pub async fn unregister(&self, id: u64) {
self.map.write().await.remove(&id);
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
let m = self.map.read().await;
if let Some(tx) = m.get(&id) {
tx.send(resp).await.is_ok()
} else { false }
}
}
// ========== RPC Writer (streaming CBC) ==========
struct RpcWriter {
writer: tokio::io::WriteHalf<TcpStream>,
key: [u8; 32],
iv: [u8; 16],
seq_no: i32,
}
impl RpcWriter {
async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload);
self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16;
let mut buf = frame;
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(self.key, self.iv);
cipher.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("{}", e)))?;
if buf.len() >= 16 {
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
}
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
}
}
// ========== RPC_PROXY_REQ ==========
fn build_proxy_req_payload(
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Vec<u8> {
// flags are pre-calculated by proto_flags_for_tag
// We just need to ensure FLAG_HAS_AD_TAG is set if we have a tag (it is set by default in our new function, but let's be safe)
let mut flags = proto_flags;
// The C code logic:
// flags = (transport_flags) | 0x1000 | 0x20000 | 0x8 (if tag)
// Our proto_flags_for_tag returns: 0x8 | 0x1000 | 0x20000 | transport_flags
// So we are good.
let b_cap = 128 + data.len();
let mut b = Vec::with_capacity(b_cap);
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
b.extend_from_slice(&flags.to_le_bytes());
b.extend_from_slice(&conn_id.to_le_bytes());
// Client IP (16 bytes IPv4-mapped-v6) + port (4 bytes)
match client_addr.ip() {
IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)),
IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()),
}
b.extend_from_slice(&(client_addr.port() as u32).to_le_bytes());
// Our IP (16 bytes) + port (4 bytes)
match our_addr.ip() {
IpAddr::V4(v4) => b.extend_from_slice(&ipv4_to_mapped_v6(v4)),
IpAddr::V6(v6) => b.extend_from_slice(&v6.octets()),
}
b.extend_from_slice(&(our_addr.port() as u32).to_le_bytes());
// Extra section (proxy_tag)
if flags & 12 != 0 {
let extra_start = b.len();
b.extend_from_slice(&0u32.to_le_bytes()); // placeholder
if let Some(tag) = proxy_tag {
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
// TL string encoding
if tag.len() < 254 {
b.push(tag.len() as u8);
b.extend_from_slice(tag);
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
b.extend(std::iter::repeat(0u8).take(pad));
} else {
b.push(0xfe);
let len_bytes = (tag.len() as u32).to_le_bytes();
b.extend_from_slice(&len_bytes[..3]);
b.extend_from_slice(tag);
let pad = (4 - (tag.len() % 4)) % 4;
b.extend(std::iter::repeat(0u8).take(pad));
}
}
let extra_bytes = (b.len() - extra_start - 4) as u32;
let eb = extra_bytes.to_le_bytes();
b[extra_start..extra_start + 4].copy_from_slice(&eb);
}
b.extend_from_slice(data);
b
}
// ========== ME Pool ==========
pub struct MePool {
registry: Arc<ConnRegistry>,
writers: Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>>,
rr: AtomicU64,
proxy_tag: Option<Vec<u8>>,
/// Telegram proxy-secret (binary, 32-512 bytes)
proxy_secret: Vec<u8>,
pool_size: usize,
}
impl MePool {
pub fn new(proxy_tag: Option<Vec<u8>>, proxy_secret: Vec<u8>) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
proxy_tag,
proxy_secret,
pool_size: 2,
})
}
pub fn registry(&self) -> &Arc<ConnRegistry> {
&self.registry
}
fn writers_arc(&self) -> Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>> {
self.writers.clone()
}
/// key_selector = first 4 bytes of proxy-secret as LE u32
/// C: main_secret.key_signature via union { char secret[]; int key_signature; }
fn key_selector(&self) -> u32 {
if self.proxy_secret.len() >= 4 {
u32::from_le_bytes([
self.proxy_secret[0], self.proxy_secret[1],
self.proxy_secret[2], self.proxy_secret[3],
])
} else { 0 }
}
pub async fn init(
self: &Arc<Self>,
pool_size: usize,
rng: &SecureRandom,
) -> Result<()> {
let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4;
let ks = self.key_selector();
info!(
me_servers = addrs.len(),
pool_size,
key_selector = format_args!("0x{:08x}", ks),
secret_len = self.proxy_secret.len(),
"Initializing ME pool"
);
for &(ip, port) in addrs.iter() {
for i in 0..pool_size {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng).await {
Ok(()) => info!(%addr, idx = i, "ME connected"),
Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"),
}
}
if self.writers.read().await.len() >= pool_size {
break;
}
}
if self.writers.read().await.is_empty() {
return Err(ProxyError::Proxy("No ME connections".into()));
}
Ok(())
}
async fn connect_one(
self: &Arc<Self>,
addr: SocketAddr,
rng: &SecureRandom,
) -> Result<()> {
let secret = &self.proxy_secret;
if secret.len() < 32 {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
}
// ===== TCP connect =====
let stream = timeout(
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
TcpStream::connect(addr),
)
.await
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })?
.map_err(ProxyError::Io)?;
stream.set_nodelay(true).ok();
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let (mut rd, mut wr) = tokio::io::split(stream);
// ===== 1. Send RPC nonce (plaintext, seq=-2) =====
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
let crypto_ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
let ks = self.key_selector();
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
debug!(
%addr,
frame_len = nonce_frame.len(),
key_sel = format_args!("0x{:08x}", ks),
crypto_ts,
"Sending nonce"
);
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
// ===== 2. Read server nonce (plaintext, seq=-2) =====
let (srv_seq, srv_nonce_payload) = timeout(
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
read_rpc_frame_plaintext(&mut rd),
)
.await
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
if srv_seq != -2 {
return Err(ProxyError::InvalidHandshake(
format!("Expected seq=-2, got {}", srv_seq),
));
}
let (schema, _srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
if schema != RPC_CRYPTO_AES_U32 {
return Err(ProxyError::InvalidHandshake(
format!("Unsupported crypto schema: 0x{:x}", schema),
));
}
debug!(%addr, "Nonce exchange OK, deriving keys");
// ===== 3. Derive AES-256-CBC keys =====
// C buffer layout:
// [0..16] nonce_server (srv_nonce)
// [16..32] nonce_client (my_nonce)
// [32..36] client_timestamp
// [36..40] server_ip
// [40..42] client_port
// [42..48] "CLIENT" or "SERVER"
// [48..52] client_ip
// [52..54] server_port
// [54..54+N] secret (proxy-secret binary)
// [54+N..70+N] nonce_server
// nonce_client(16)
let ts_bytes = crypto_ts.to_le_bytes();
let server_ip = addr_to_ip_u32(&peer_addr);
let client_ip = addr_to_ip_u32(&local_addr);
let server_ip_bytes = server_ip.to_le_bytes();
let client_ip_bytes = client_ip.to_le_bytes();
let server_port_bytes = peer_addr.port().to_le_bytes();
let client_port_bytes = local_addr.port().to_le_bytes();
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce, &my_nonce, &ts_bytes,
Some(&server_ip_bytes), &client_port_bytes,
b"CLIENT",
Some(&client_ip_bytes), &server_port_bytes,
secret, None, None,
);
let (rk, ri) = derive_middleproxy_keys(
&srv_nonce, &my_nonce, &ts_bytes,
Some(&server_ip_bytes), &client_port_bytes,
b"SERVER",
Some(&client_ip_bytes), &server_port_bytes,
secret, None, None,
);
debug!(
%addr,
write_key = %hex::encode(&wk[..8]),
read_key = %hex::encode(&rk[..8]),
"Keys derived"
);
// ===== 4. Send encrypted handshake (seq=-1) =====
let hs_payload = build_handshake_payload(
client_ip, local_addr.port(),
server_ip, peer_addr.port(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload);
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
debug!(%addr, enc_len = encrypted_hs.len(), "Sent encrypted handshake");
// ===== 5. Read encrypted handshake response (STREAMING) =====
// Server sends encrypted handshake. C crypto layer may send partial
// blocks (only complete 16-byte blocks get encrypted at a time).
// We read incrementally and decrypt block-by-block.
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri;
let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok {
let remaining = deadline - Instant::now();
let mut tmp = [0u8; 256];
let n = match timeout(remaining, rd.read(&mut tmp)).await {
Ok(Ok(0)) => return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof, "ME closed during handshake",
))),
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
};
enc_buf.extend_from_slice(&tmp[..n]);
// Decrypt complete 16-byte blocks
let blocks = enc_buf.len() / 16 * 16;
if blocks > 0 {
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&enc_buf[..blocks]);
let new_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
read_iv = new_iv;
dec_buf.extend_from_slice(&chunk);
let _ = enc_buf.split_to(blocks);
}
// Try to parse RPC frame from decrypted data
while dec_buf.len() >= 4 {
let fl = u32::from_le_bytes([
dec_buf[0], dec_buf[1], dec_buf[2], dec_buf[3],
]) as usize;
// Skip noop padding
if fl == 4 {
let _ = dec_buf.split_to(4);
continue;
}
if fl < 12 || fl > (1 << 24) {
return Err(ProxyError::InvalidHandshake(
format!("Bad HS response frame len: {}", fl),
));
}
if dec_buf.len() < fl {
break; // need more data
}
let frame = dec_buf.split_to(fl);
// CRC32 check
let pe = fl - 4;
let ec = u32::from_le_bytes([
frame[pe], frame[pe + 1], frame[pe + 2], frame[pe + 3],
]);
let ac = crc32(&frame[..pe]);
if ec != ac {
return Err(ProxyError::InvalidHandshake(
format!("HS CRC mismatch: 0x{:08x} vs 0x{:08x}", ec, ac),
));
}
// Check type
let hs_type = u32::from_le_bytes([
frame[8], frame[9], frame[10], frame[11],
]);
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 {
i32::from_le_bytes([frame[12], frame[13], frame[14], frame[15]])
} else { -1 };
return Err(ProxyError::InvalidHandshake(
format!("ME rejected handshake (error={})", err_code),
));
}
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(
format!("Expected HANDSHAKE 0x{:08x}, got 0x{:08x}", RPC_HANDSHAKE_U32, hs_type),
));
}
handshake_ok = true;
break;
}
}
if !handshake_ok {
return Err(ProxyError::TgHandshakeTimeout);
}
info!(%addr, "RPC handshake OK");
// ===== 6. Setup writer + reader =====
let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: wr,
key: wk,
iv: write_iv,
seq_no: 0,
}));
self.writers.write().await.push(rpc_w.clone());
let reg = self.registry.clone();
let w_pong = rpc_w.clone();
let w_pool = self.writers_arc();
tokio::spawn(async move {
if let Err(e) = reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await {
warn!(error = %e, "ME reader ended");
}
// Remove dead writer from pool
let mut ws = w_pool.write().await;
ws.retain(|w| !Arc::ptr_eq(w, &w_pong));
info!(remaining = ws.len(), "Dead ME writer removed from pool");
});
Ok(())
}
pub async fn send_proxy_req(
&self,
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proto_flags: u32,
) -> Result<()> {
let payload = build_proxy_req_payload(
conn_id, client_addr, our_addr, data,
self.proxy_tag.as_deref(), proto_flags,
);
loop {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
let idx = self.rr.fetch_add(1, Ordering::Relaxed) as usize % ws.len();
let w = ws[idx].clone();
drop(ws);
match w.lock().await.send(&payload).await {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|o| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
}
}
}
}
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
let ws = self.writers.read().await;
if !ws.is_empty() {
let w = ws[0].clone();
drop(ws);
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.lock().await.send(&p).await {
debug!(error = %e, "ME close write failed");
let mut ws = self.writers.write().await;
ws.retain(|o| !Arc::ptr_eq(o, &w));
}
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub fn connection_count(&self) -> usize {
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
}
}
// ========== Reader Loop ==========
async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
mut div: [u8; 16],
reg: Arc<ConnRegistry>,
mut enc_leftover: BytesMut,
mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>,
) -> Result<()> {
let mut raw = enc_leftover;
loop {
let mut tmp = [0u8; 16384];
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
if n == 0 { return Ok(()); }
raw.extend_from_slice(&tmp[..n]);
// Decrypt complete 16-byte blocks
let blocks = raw.len() / 16 * 16;
if blocks > 0 {
let mut new_iv = [0u8; 16];
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&raw[..blocks]);
AesCbc::new(dk, div)
.decrypt_in_place(&mut chunk)
.map_err(|e| ProxyError::Crypto(format!("{}", e)))?;
div = new_iv;
dec.extend_from_slice(&chunk);
let _ = raw.split_to(blocks);
}
// Parse RPC frames
while dec.len() >= 12 {
let fl = u32::from_le_bytes([dec[0], dec[1], dec[2], dec[3]]) as usize;
if fl == 4 { let _ = dec.split_to(4); continue; }
if fl < 12 || fl > (1 << 24) {
warn!(frame_len = fl, "Invalid RPC frame len");
dec.clear();
break;
}
if dec.len() < fl { break; }
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes([frame[pe], frame[pe+1], frame[pe+2], frame[pe+3]]);
if crc32(&frame[..pe]) != ec {
warn!("CRC mismatch in data frame");
continue;
}
let payload = &frame[8..pe];
if payload.len() < 4 { continue; }
let pt = u32::from_le_bytes([payload[0], payload[1], payload[2], payload[3]]);
let body = &payload[4..];
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
let data = Bytes::copy_from_slice(&body[12..]);
trace!(cid, len = data.len(), flags, "ANS");
reg.route(cid, MeResponse::Data(data)).await;
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
trace!(cid, cfm, "ACK");
reg.route(cid, MeResponse::Ack(cfm)).await;
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "CLOSE_EXT from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "CLOSE_CONN from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
trace!(ping_id, "RPC_PING -> PONG");
let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes());
if let Err(e) = writer.lock().await.send(&pong).await {
warn!(error = %e, "PONG send failed");
break;
}
} else {
debug!(rpc_type = format_args!("0x{:08x}", pt), len = body.len(), "Unknown RPC");
}
}
}
}
// ========== Proto flags ==========
/// Map ProtoTag to C-compatible RPC_PROXY_REQ transport flags.
/// C: RPC_F_COMPACT(0x40000000)=abridged, RPC_F_MEDIUM(0x20000000)=intermediate/secure
/// The 0x1000(magic) and 0x8(proxy_tag) are added inside build_proxy_req_payload.
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag) -> u32 {
use crate::protocol::constants::*;
let mut flags = RPC_FLAG_HAS_AD_TAG | RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
match tag {
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
}
}
// ========== Health Monitor (Phase 4) ==========
pub async fn me_health_monitor(
pool: Arc<MePool>,
rng: Arc<SecureRandom>,
min_connections: usize,
) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
let current = pool.writers.read().await.len();
if current < min_connections {
warn!(current, min = min_connections, "ME pool below minimum, reconnecting...");
let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone();
for &(ip, port) in addrs.iter() {
let needed = min_connections.saturating_sub(pool.writers.read().await.len());
if needed == 0 { break; }
for _ in 0..needed {
let addr = SocketAddr::new(ip, port);
match pool.connect_one(addr, &rng).await {
Ok(()) => info!(%addr, "ME reconnected"),
Err(e) => debug!(%addr, error = %e, "ME reconnect failed"),
}
}
}
}
}
}

View File

@@ -0,0 +1,179 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload);
let c = crc32(&frame);
frame.extend_from_slice(&c.to_le_bytes());
frame
}
pub(crate) async fn read_rpc_frame_plaintext(
rd: &mut (impl AsyncReadExt + Unpin),
) -> Result<(i32, Vec<u8>)> {
let mut len_buf = [0u8; 4];
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
let total_len = u32::from_le_bytes(len_buf) as usize;
if !(12..=(1 << 24)).contains(&total_len) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad RPC frame length: {total_len}"
)));
}
let mut rest = vec![0u8; total_len - 4];
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
let mut full = Vec::with_capacity(total_len);
full.extend_from_slice(&len_buf);
full.extend_from_slice(&rest);
let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]);
if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
)));
}
let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap());
let payload = full[8..crc_offset].to_vec();
Ok((seq_no, payload))
}
pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
p[16..32].copy_from_slice(nonce);
p
}
pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, u32, [u8; 16])> {
if d.len() < 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Nonce payload too short: {} bytes",
d.len()
)));
}
let t = u32::from_le_bytes(d[0..4].try_into().unwrap());
if t != RPC_NONCE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}"
)));
}
let key_select = u32::from_le_bytes(d[4..8].try_into().unwrap());
let schema = u32::from_le_bytes(d[8..12].try_into().unwrap());
let ts = u32::from_le_bytes(d[12..16].try_into().unwrap());
let mut nonce = [0u8; 16];
nonce.copy_from_slice(&d[16..32]);
Ok((key_select, schema, ts, nonce))
}
pub(crate) fn build_handshake_payload(
our_ip: [u8; 4],
our_port: u16,
peer_ip: [u8; 4],
peer_port: u16,
) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes.
p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16;
p[14..16].copy_from_slice(&pid.to_le_bytes());
let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes());
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
}
pub(crate) fn cbc_encrypt_padded(
key: &[u8; 32],
iv: &[u8; 16],
plaintext: &[u8],
) -> Result<(Vec<u8>, [u8; 16])> {
let pad = (16 - (plaintext.len() % 16)) % 16;
let mut buf = plaintext.to_vec();
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(*key, *iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?;
let mut new_iv = [0u8; 16];
if buf.len() >= 16 {
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
}
Ok((buf, new_iv))
}
pub(crate) fn cbc_decrypt_inplace(
key: &[u8; 32],
iv: &[u8; 16],
data: &mut [u8],
) -> Result<[u8; 16]> {
let mut new_iv = [0u8; 16];
if data.len() >= 16 {
new_iv.copy_from_slice(&data[data.len() - 16..]);
}
AesCbc::new(*key, *iv)
.decrypt_in_place(data)
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?;
Ok(new_iv)
}
pub(crate) struct RpcWriter {
pub(crate) writer: tokio::io::WriteHalf<tokio::net::TcpStream>,
pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32,
}
impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload);
self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16;
let mut buf = frame;
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(self.key, self.iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if buf.len() >= 16 {
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
}
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
}
}

View File

@@ -0,0 +1,38 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::crypto::SecureRandom;
use crate::protocol::constants::TG_MIDDLE_PROXIES_FLAT_V4;
use super::MePool;
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, min_connections: usize) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
let current = pool.connection_count();
if current < min_connections {
warn!(
current,
min = min_connections,
"ME pool below minimum, reconnecting..."
);
let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone();
for &(ip, port) in addrs.iter() {
let needed = min_connections.saturating_sub(pool.connection_count());
if needed == 0 {
break;
}
for _ in 0..needed {
let addr = SocketAddr::new(ip, port);
match pool.connect_one(addr, &rng).await {
Ok(()) => info!(%addr, "ME reconnected"),
Err(e) => debug!(%addr, error = %e, "ME reconnect failed"),
}
}
}
}
}
}

View File

@@ -0,0 +1,26 @@
//! Middle Proxy RPC transport.
mod codec;
mod health;
mod pool;
mod pool_nat;
mod reader;
mod registry;
mod send;
mod secret;
mod wire;
use bytes::Bytes;
pub use health::me_health_monitor;
pub use pool::MePool;
pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret;
pub use wire::proto_flags_for_tag;
#[derive(Debug)]
pub enum MeResponse {
Data { flags: u32, data: Bytes },
Ack(u32),
Close,
}

View File

@@ -0,0 +1,499 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use bytes::BytesMut;
use rand::Rng;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{Instant, timeout};
use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::ConnRegistry;
use super::codec::{
RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext,
};
use super::reader::reader_loop;
use super::wire::{IpMaterial, extract_ip_material};
const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
pub struct MePool {
pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>> ,
pub(super) rr: AtomicU64,
pub(super) proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
pub(super) nat_ip_cfg: Option<IpAddr>,
pub(super) nat_ip_detected: OnceLock<IpAddr>,
pub(super) nat_probe: bool,
pub(super) nat_stun: Option<String>,
pool_size: usize,
}
impl MePool {
pub fn new(
proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
nat_ip: Option<IpAddr>,
nat_probe: bool,
nat_stun: Option<String>,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
proxy_tag,
proxy_secret,
nat_ip_cfg: nat_ip,
nat_ip_detected: OnceLock::new(),
nat_probe,
nat_stun,
pool_size: 2,
})
}
pub fn has_proxy_tag(&self) -> bool {
self.proxy_tag.is_some()
}
pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
let ip = self.translate_ip_for_nat(addr.ip());
SocketAddr::new(ip, addr.port())
}
pub fn registry(&self) -> &Arc<ConnRegistry> {
&self.registry
}
fn writers_arc(&self) -> Arc<RwLock<Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)>>>
{
self.writers.clone()
}
fn key_selector(&self) -> u32 {
if self.proxy_secret.len() >= 4 {
u32::from_le_bytes([
self.proxy_secret[0],
self.proxy_secret[1],
self.proxy_secret[2],
self.proxy_secret[3],
])
} else {
0
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> {
let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4;
let ks = self.key_selector();
info!(
me_servers = addrs.len(),
pool_size,
key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.len(),
"Initializing ME pool"
);
for &(ip, port) in addrs.iter() {
for i in 0..pool_size {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng).await {
Ok(()) => info!(%addr, idx = i, "ME connected"),
Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"),
}
}
if self.writers.read().await.len() >= pool_size {
break;
}
}
if self.writers.read().await.is_empty() {
return Err(ProxyError::Proxy("No ME connections".into()));
}
Ok(())
}
pub(crate) async fn connect_one(
self: &Arc<Self>,
addr: SocketAddr,
rng: &SecureRandom,
) -> Result<()> {
let secret = &self.proxy_secret;
if secret.len() < 32 {
return Err(ProxyError::Proxy(
"proxy-secret too short for ME auth".into(),
));
}
let stream = timeout(
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
TcpStream::connect(addr),
)
.await
.map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string(),
})?
.map_err(ProxyError::Io)?;
stream.set_nodelay(true).ok();
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
let reflected = if self.nat_probe {
self.maybe_reflect_public_addr().await
} else {
None
};
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
let peer_addr_nat =
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
let (mut rd, mut wr) = tokio::io::split(stream);
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
let crypto_ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
let ks = self.key_selector();
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
info!(
key_selector = format_args!("0x{ks:08x}"),
crypto_ts,
frame_len = nonce_frame.len(),
nonce_frame_hex = %dump,
"Sending ME nonce frame"
);
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let (srv_seq, srv_nonce_payload) = timeout(
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
read_rpc_frame_plaintext(&mut rd),
)
.await
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
if srv_seq != -2 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected seq=-2, got {srv_seq}"
)));
}
let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
if schema != RPC_CRYPTO_AES_U32 {
warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema");
return Err(ProxyError::InvalidHandshake(format!(
"Unsupported crypto schema: 0x{schema:x}"
)));
}
if srv_key_select != ks {
return Err(ProxyError::InvalidHandshake(format!(
"Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}"
)));
}
let skew = crypto_ts.abs_diff(srv_ts);
if skew > 30 {
return Err(ProxyError::InvalidHandshake(format!(
"nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s"
)));
}
info!(
%local_addr,
%local_addr_nat,
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
%peer_addr,
%peer_addr_nat,
key_selector = format_args!("0x{ks:08x}"),
crypto_schema = format_args!("0x{schema:08x}"),
skew_secs = skew,
"ME key derivation parameters"
);
let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let client_port_bytes = local_addr_nat.port().to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat);
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
match (server_ip, client_ip) {
(IpMaterial::V4(srv), IpMaterial::V4(clt)) => {
(Some(srv), Some(clt), None, None, clt, srv)
}
(IpMaterial::V6(srv), IpMaterial::V6(clt)) => {
let zero = [0u8; 4];
(None, None, Some(clt), Some(srv), zero, zero)
}
_ => {
return Err(ProxyError::InvalidHandshake(
"mixed IPv4/IPv6 endpoints are not supported for ME key derivation"
.to_string(),
));
}
};
let diag_level: u8 = std::env::var("ME_DIAG")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let prekey_client = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let prekey_server = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (rk, ri) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let hs_payload =
build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port());
let hs_frame = build_rpc_frame(-1, &hs_payload);
if diag_level >= 1 {
info!(
write_key = %hex_dump(&wk),
write_iv = %hex_dump(&wi),
read_key = %hex_dump(&rk),
read_iv = %hex_dump(&ri),
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
srv_port = %hex_dump(&server_port_bytes),
clt_port = %hex_dump(&client_port_bytes),
crypto_ts = %hex_dump(&ts_bytes),
nonce_srv = %hex_dump(&srv_nonce),
nonce_clt = %hex_dump(&my_nonce),
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
hs_plain = %hex_dump(&hs_frame),
proxy_secret_sha256 = %hex_dump(&sha256(secret)),
"ME diag: derived keys and handshake plaintext"
);
}
if diag_level >= 2 {
info!(
prekey_client = %hex_dump(&prekey_client),
prekey_server = %hex_dump(&prekey_server),
"ME diag: full prekey buffers"
);
}
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 {
info!(
hs_cipher = %hex_dump(&encrypted_hs),
"ME diag: handshake ciphertext"
);
}
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri;
let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok {
let remaining = deadline - Instant::now();
let mut tmp = [0u8; 256];
let n = match timeout(remaining, rd.read(&mut tmp)).await {
Ok(Ok(0)) => {
return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"ME closed during handshake",
)));
}
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
};
enc_buf.extend_from_slice(&tmp[..n]);
let blocks = enc_buf.len() / 16 * 16;
if blocks > 0 {
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&enc_buf[..blocks]);
read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
dec_buf.extend_from_slice(&chunk);
let _ = enc_buf.split_to(blocks);
}
while dec_buf.len() >= 4 {
let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec_buf.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad HS response frame len: {fl}"
)));
}
if dec_buf.len() < fl {
break;
}
let frame = dec_buf.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]);
if ec != ac {
return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
)));
}
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 {
i32::from_le_bytes(frame[12..16].try_into().unwrap())
} else {
-1
};
return Err(ProxyError::InvalidHandshake(format!(
"ME rejected handshake (error={err_code})"
)));
}
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
handshake_ok = true;
break;
}
}
if !handshake_ok {
return Err(ProxyError::TgHandshakeTimeout);
}
info!(%addr, "RPC handshake OK");
let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: wr,
key: wk,
iv: write_iv,
seq_no: 0,
}));
self.writers.write().await.push((addr, rpc_w.clone()));
let reg = self.registry.clone();
let w_pong = rpc_w.clone();
let w_pool = self.writers_arc();
let w_ping = rpc_w.clone();
let w_pool_ping = self.writers_arc();
tokio::spawn(async move {
if let Err(e) =
reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await
{
warn!(error = %e, "ME reader ended");
}
let mut ws = w_pool.write().await;
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_pong));
info!(remaining = ws.len(), "Dead ME writer removed from pool");
});
tokio::spawn(async move {
let mut ping_id: i64 = rand::random::<i64>();
loop {
let jitter = rand::rng()
.random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
tokio::time::sleep(Duration::from_secs(wait)).await;
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1);
if let Err(e) = w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
let mut ws = w_pool_ping.write().await;
ws.retain(|(_, w)| !Arc::ptr_eq(w, &w_ping));
break;
}
}
});
Ok(())
}
}
fn hex_dump(data: &[u8]) -> String {
const MAX: usize = 64;
let mut out = String::with_capacity(data.len() * 2 + 3);
for (i, b) in data.iter().take(MAX).enumerate() {
if i > 0 {
out.push(' ');
}
out.push_str(&format!("{b:02x}"));
}
if data.len() > MAX {
out.push_str("");
}
out
}

View File

@@ -0,0 +1,200 @@
use std::net::{IpAddr, Ipv4Addr};
use tracing::{info, warn};
use crate::error::{ProxyError, Result};
use super::MePool;
impl MePool {
pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
let nat_ip = self
.nat_ip_cfg
.or_else(|| self.nat_ip_detected.get().copied());
let Some(nat_ip) = nat_ip else {
return ip;
};
match (ip, nat_ip) {
(IpAddr::V4(src), IpAddr::V4(dst))
if is_privateish(IpAddr::V4(src))
|| src.is_loopback()
|| src.is_unspecified() =>
{
IpAddr::V4(dst)
}
(IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => {
IpAddr::V6(dst)
}
(orig, _) => orig,
}
}
pub(super) fn translate_our_addr_with_reflection(
&self,
addr: std::net::SocketAddr,
reflected: Option<std::net::SocketAddr>,
) -> std::net::SocketAddr {
let ip = if let Some(r) = reflected {
// Use reflected IP (not port) only when local address is non-public.
if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
r.ip()
} else {
self.translate_ip_for_nat(addr.ip())
}
} else {
self.translate_ip_for_nat(addr.ip())
};
// Keep the kernel-assigned TCP source port; STUN port can differ.
std::net::SocketAddr::new(ip, addr.port())
}
pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option<IpAddr> {
if self.nat_ip_cfg.is_some() {
return self.nat_ip_cfg;
}
if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None;
}
if let Some(ip) = self.nat_ip_detected.get().copied() {
return Some(ip);
}
match fetch_public_ipv4().await {
Ok(Some(ip)) => {
let _ = self.nat_ip_detected.set(IpAddr::V4(ip));
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
Ok(None) => None,
Err(e) => {
warn!(error = %e, "Failed to auto-detect public IP");
None
}
}
}
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> {
let stun_addr = self
.nat_stun
.clone()
.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
match fetch_stun_binding(&stun_addr).await {
Ok(sa) => {
if let Some(sa) = sa {
info!(%sa, "NAT probe: reflected address");
}
sa
}
Err(e) => {
warn!(error = %e, "NAT probe failed");
None
}
}
}
}
async fn fetch_public_ipv4() -> Result<Option<Ipv4Addr>> {
let res = reqwest::get("https://checkip.amazonaws.com").await.map_err(|e| {
ProxyError::Proxy(format!("public IP detection request failed: {e}"))
})?;
let text = res.text().await.map_err(|e| {
ProxyError::Proxy(format!("public IP detection read failed: {e}"))
})?;
let ip = text.trim().parse().ok();
Ok(ip)
}
async fn fetch_stun_binding(stun_addr: &str) -> Result<Option<std::net::SocketAddr>> {
use rand::RngCore;
use tokio::net::UdpSocket;
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
socket
.connect(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
// Build minimal Binding Request.
let mut req = vec![0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::thread_rng().fill_bytes(&mut req[8..20]);
socket
.send(&req)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
let mut buf = [0u8; 128];
let n = socket
.recv(&mut buf)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?;
if n < 20 {
return Ok(None);
}
// Parse attributes.
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 {
break;
}
let family = buf[idx + 1];
if family != 0x01 {
// only IPv4 supported here
break;
}
let port_bytes = [buf[idx + 2], buf[idx + 3]];
let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]];
let (port, ip) = if atype == 0x0020 {
let magic = 0x2112A442u32.to_be_bytes();
let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16);
let ip = [
ip_bytes[0] ^ magic[0],
ip_bytes[1] ^ magic[1],
ip_bytes[2] ^ magic[2],
ip_bytes[3] ^ magic[3],
];
(port, ip)
} else {
(u16::from_be_bytes(port_bytes), ip_bytes)
};
return Ok(Some(std::net::SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])),
port,
)));
}
_ => {}
}
idx += (alen + 3) & !3; // 4-byte alignment
}
Ok(None)
}
fn is_privateish(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
IpAddr::V6(v6) => v6.is_unique_local(),
}
}

View File

@@ -0,0 +1,141 @@
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::codec::RpcWriter;
use super::{ConnRegistry, MeResponse};
pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
mut div: [u8; 16],
reg: Arc<ConnRegistry>,
enc_leftover: BytesMut,
mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>,
) -> Result<()> {
let mut raw = enc_leftover;
loop {
let mut tmp = [0u8; 16_384];
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
if n == 0 {
return Ok(());
}
raw.extend_from_slice(&tmp[..n]);
let blocks = raw.len() / 16 * 16;
if blocks > 0 {
let mut new_iv = [0u8; 16];
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&raw[..blocks]);
AesCbc::new(dk, div)
.decrypt_in_place(&mut chunk)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
div = new_iv;
dec.extend_from_slice(&chunk);
let _ = raw.split_to(blocks);
}
while dec.len() >= 12 {
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
warn!(frame_len = fl, "Invalid RPC frame len");
dec.clear();
break;
}
if dec.len() < fl {
break;
}
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec {
warn!("CRC mismatch in data frame");
continue;
}
let payload = &frame[8..pe];
if payload.len() < 4 {
continue;
}
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
let body = &payload[4..];
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
let data = Bytes::copy_from_slice(&body[12..]);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_EXT from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_CONN from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
trace!(ping_id, "RPC_PING -> RPC_PONG");
let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes());
if let Err(e) = writer.lock().await.send(&pong).await {
warn!(error = %e, "PONG send failed");
break;
}
} else {
debug!(
rpc_type = format_args!("0x{pt:08x}"),
len = body.len(),
"Unknown RPC"
);
}
}
}
}
async fn send_close_conn(writer: &Arc<Mutex<RpcWriter>>, conn_id: u64) {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = writer.lock().await.send(&p).await {
debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN");
}
}

View File

@@ -0,0 +1,42 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{RwLock, mpsc};
use super::MeResponse;
pub struct ConnRegistry {
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
next_id: AtomicU64,
}
impl ConnRegistry {
pub fn new() -> Self {
// Avoid fully predictable conn_id sequence from 1.
let start = rand::random::<u64>() | 1;
Self {
map: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(start),
}
}
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(256);
self.map.write().await.insert(id, tx);
(id, rx)
}
pub async fn unregister(&self, id: u64) {
self.map.write().await.remove(&id);
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
let m = self.map.read().await;
if let Some(tx) = m.get(&id) {
tx.send(resp).await.is_ok()
} else {
false
}
}
}

View File

@@ -0,0 +1,81 @@
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::error::{ProxyError, Result};
/// Fetch Telegram proxy-secret binary.
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
let cache = cache_path.unwrap_or("proxy-secret");
// 1) Try fresh download first.
match download_proxy_secret().await {
Ok(data) => {
if let Err(e) = tokio::fs::write(cache, &data).await {
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
} else {
debug!(path = cache, len = data.len(), "Cached proxy-secret");
}
return Ok(data);
}
Err(download_err) => {
warn!(error = %download_err, "Proxy-secret download failed, trying cache/file fallback");
// Fall through to cache/file.
}
}
// 2) Fallback to cache/file regardless of age; require len>=32.
match tokio::fs::read(cache).await {
Ok(data) if data.len() >= 32 => {
let age_hours = tokio::fs::metadata(cache)
.await
.ok()
.and_then(|m| m.modified().ok())
.and_then(|m| std::time::SystemTime::now().duration_since(m).ok())
.map(|d| d.as_secs() / 3600);
info!(
path = cache,
len = data.len(),
age_hours,
"Loaded proxy-secret from cache/file after download failure"
);
Ok(data)
}
Ok(data) => Err(ProxyError::Proxy(format!(
"Cached proxy-secret too short: {} bytes (need >= 32)",
data.len()
))),
Err(e) => Err(ProxyError::Proxy(format!(
"Failed to read proxy-secret cache after download failure: {e}"
))),
}
}
async fn download_proxy_secret() -> Result<Vec<u8>> {
let resp = reqwest::get("https://core.telegram.org/getProxySecret")
.await
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?;
if !resp.status().is_success() {
return Err(ProxyError::Proxy(format!(
"proxy-secret download HTTP {}",
resp.status()
)));
}
let data = resp
.bytes()
.await
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))?
.to_vec();
if data.len() < 32 {
return Err(ProxyError::Proxy(format!(
"proxy-secret too short: {} bytes (need >= 32)",
data.len()
)));
}
info!(len = data.len(), "Downloaded proxy-secret OK");
Ok(data)
}

View File

@@ -0,0 +1,146 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::{RPC_CLOSE_EXT_U32, TG_MIDDLE_PROXIES_V4};
use super::MePool;
use super::codec::RpcWriter;
use super::wire::build_proxy_req_payload;
impl MePool {
pub async fn send_proxy_req(
&self,
conn_id: u64,
target_dc: i16,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proto_flags: u32,
) -> Result<()> {
let payload = build_proxy_req_payload(
conn_id,
client_addr,
our_addr,
data,
self.proxy_tag.as_deref(),
proto_flags,
);
loop {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
let writers: Vec<(SocketAddr, Arc<Mutex<RpcWriter>>)> = ws.iter().cloned().collect();
drop(ws);
let candidate_indices = candidate_indices_for_dc(&writers, target_dc);
if candidate_indices.is_empty() {
return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
}
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
// Prefer immediately available writer to avoid waiting on stalled connection.
for offset in 0..candidate_indices.len() {
let cidx = (start + offset) % candidate_indices.len();
let idx = candidate_indices[cidx];
let w = writers[idx].1.clone();
if let Ok(mut guard) = w.try_lock() {
let send_res = guard.send(&payload).await;
drop(guard);
match send_res {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
continue;
}
}
}
}
// All writers are currently busy, wait for the selected one.
let w = writers[candidate_indices[start]].1.clone();
match w.lock().await.send(&payload).await {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
}
}
}
}
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
let ws = self.writers.read().await;
if !ws.is_empty() {
let w = ws[0].1.clone();
drop(ws);
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.lock().await.send(&p).await {
debug!(error = %e, "ME close write failed");
let mut ws = self.writers.write().await;
ws.retain(|(_, o)| !Arc::ptr_eq(o, &w));
}
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub fn connection_count(&self) -> usize {
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
}
}
fn candidate_indices_for_dc(
writers: &[(SocketAddr, Arc<Mutex<RpcWriter>>)],
target_dc: i16,
) -> Vec<usize> {
let mut preferred = Vec::<SocketAddr>::new();
let key = target_dc as i32;
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&key) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = TG_MIDDLE_PROXIES_V4.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
return (0..writers.len()).collect();
}
let mut out = Vec::new();
for (idx, (addr, _)) in writers.iter().enumerate() {
if preferred.iter().any(|p| p == addr) {
out.push(idx);
}
}
if out.is_empty() {
return (0..writers.len()).collect();
}
out
}

View File

@@ -0,0 +1,106 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use crate::protocol::constants::*;
#[derive(Clone, Copy)]
pub(crate) enum IpMaterial {
V4([u8; 4]),
V6([u8; 16]),
}
pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial {
match addr.ip() {
IpAddr::V4(v4) => IpMaterial::V4(v4.octets()),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpMaterial::V4(v4.octets())
} else {
IpMaterial::V6(v6.octets())
}
}
}
}
fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] {
let mut buf = [0u8; 16];
// Matches tl_store_long(0) + tl_store_int(-0x10000).
buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes());
// Matches tl_store_int(htonl(remote_ip_host_order)).
let host_order = u32::from_ne_bytes(ip.octets());
let network_order = host_order.to_be();
buf[12..16].copy_from_slice(&network_order.to_le_bytes());
buf
}
fn append_mapped_addr_and_port(buf: &mut Vec<u8>, addr: SocketAddr) {
match addr.ip() {
IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)),
IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()),
}
buf.extend_from_slice(&(addr.port() as u32).to_le_bytes());
}
pub(crate) fn build_proxy_req_payload(
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Vec<u8> {
let mut b = Vec::with_capacity(128 + data.len());
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
b.extend_from_slice(&proto_flags.to_le_bytes());
b.extend_from_slice(&conn_id.to_le_bytes());
append_mapped_addr_and_port(&mut b, client_addr);
append_mapped_addr_and_port(&mut b, our_addr);
if proto_flags & 12 != 0 {
let extra_start = b.len();
b.extend_from_slice(&0u32.to_le_bytes());
if let Some(tag) = proxy_tag {
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
if tag.len() < 254 {
b.push(tag.len() as u8);
b.extend_from_slice(tag);
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
} else {
b.push(0xfe);
let len_bytes = (tag.len() as u32).to_le_bytes();
b.extend_from_slice(&len_bytes[..3]);
b.extend_from_slice(tag);
let pad = (4 - (tag.len() % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
}
}
let extra_bytes = (b.len() - extra_start - 4) as u32;
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
}
b.extend_from_slice(data);
b
}
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 {
use crate::protocol::constants::ProtoTag;
let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
if has_proxy_tag {
flags |= RPC_FLAG_HAS_AD_TAG;
}
match tag {
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
}
}

View File

@@ -10,4 +10,5 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*; pub use socket::*;
pub use socks::*; pub use socks::*;
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager};
pub mod middle_proxy;

View File

@@ -1,4 +1,6 @@
//! Upstream Management with per-DC latency-weighted selection //! Upstream Management with per-DC latency-weighted selection
//!
//! IPv6/IPv4 connectivity checks with configurable preference.
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
use std::sync::Arc; use std::sync::Arc;
@@ -18,6 +20,9 @@ use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters /// Number of Telegram datacenters
const NUM_DCS: usize = 5; const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5;
// ============= RTT Tracking ============= // ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@@ -43,6 +48,29 @@ impl LatencyEma {
} }
} }
// ============= Per-DC IP Preference Tracking =============
/// Tracks which IP version works for each DC
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IpPreference {
/// Not yet tested
Unknown,
/// IPv6 works
PreferV6,
/// Only IPv4 works (IPv6 failed)
PreferV4,
/// Both work
BothWork,
/// Both failed
Unavailable,
}
impl Default for IpPreference {
fn default() -> Self {
Self::Unknown
}
}
// ============= Upstream State ============= // ============= Upstream State =============
#[derive(Debug)] #[derive(Debug)]
@@ -53,6 +81,8 @@ struct UpstreamState {
last_check: std::time::Instant, last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5) /// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS], dc_latency: [LatencyEma; NUM_DCS],
/// Per-DC IP version preference (learned from connectivity tests)
dc_ip_pref: [IpPreference; NUM_DCS],
} }
impl UpstreamState { impl UpstreamState {
@@ -63,16 +93,11 @@ impl UpstreamState {
fails: 0, fails: 0,
last_check: std::time::Instant::now(), last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS], dc_latency: [LatencyEma::new(0.3); NUM_DCS],
dc_ip_pref: [IpPreference::Unknown; NUM_DCS],
} }
} }
/// Map DC index to latency array slot (0..NUM_DCS). /// Map DC index to latency array slot (0..NUM_DCS).
///
/// Matches the C implementation's `mf_cluster_lookup` behavior:
/// - Standard DCs ±1..±5 → direct mapping to array index 0..4
/// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2)
/// This matches Telegram's `default 2;` in proxy-multi.conf.
/// - There is NO modular arithmetic in the C implementation.
fn dc_array_idx(dc_idx: i16) -> Option<usize> { fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let abs_dc = dc_idx.unsigned_abs() as usize; let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc == 0 { if abs_dc == 0 {
@@ -82,21 +107,18 @@ impl UpstreamState {
Some(abs_dc - 1) Some(abs_dc - 1)
} else { } else {
// Unknown DC → default cluster (DC 2, index 1) // Unknown DC → default cluster (DC 2, index 1)
// Same as C: mf_cluster_lookup returns default_cluster
Some(1) Some(1)
} }
} }
/// Get latency for a specific DC, falling back to average across all known DCs /// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> { fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
// Try DC-specific latency first
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) { if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() { if let Some(ms) = self.dc_latency[di].get() {
return Some(ms); return Some(ms);
} }
} }
// Fallback: average of all known DC latencies
let (sum, count) = self.dc_latency.iter() let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get()) .filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1)); .fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
@@ -114,11 +136,14 @@ pub struct DcPingResult {
pub error: Option<String>, pub error: Option<String>,
} }
/// Result of startup ping for one upstream /// Result of startup ping for one upstream (separate v6/v4 results)
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct StartupPingResult { pub struct StartupPingResult {
pub results: Vec<DcPingResult>, pub v6_results: Vec<DcPingResult>,
pub v4_results: Vec<DcPingResult>,
pub upstream_name: String, pub upstream_name: String,
/// True if both IPv6 and IPv4 have at least one working DC
pub both_available: bool,
} }
// ============= Upstream Manager ============= // ============= Upstream Manager =============
@@ -141,15 +166,6 @@ impl UpstreamManager {
} }
/// Select upstream using latency-weighted random selection. /// Select upstream using latency-weighted random selection.
///
/// `effective_weight = config_weight × latency_factor`
///
/// where `latency_factor = 1000 / latency_ms` if latency is known,
/// or `1.0` if no latency data is available.
///
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
/// factor 5 — the faster route is 4× more likely to be chosen
/// (all else being equal).
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> { async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await; let upstreams = self.upstreams.read().await;
if upstreams.is_empty() { if upstreams.is_empty() {
@@ -163,7 +179,6 @@ impl UpstreamManager {
.collect(); .collect();
if healthy.is_empty() { if healthy.is_empty() {
// All unhealthy — pick any
return Some(rand::rng().gen_range(0..upstreams.len())); return Some(rand::rng().gen_range(0..upstreams.len()));
} }
@@ -171,7 +186,6 @@ impl UpstreamManager {
return Some(healthy[0]); return Some(healthy[0]);
} }
// Calculate latency-weighted scores
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| { let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64; let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx) let latency_factor = upstreams[i].effective_latency(dc_idx)
@@ -207,9 +221,6 @@ impl UpstreamManager {
} }
/// Connect to target through a selected upstream. /// Connect to target through a selected upstream.
///
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
/// Pass `None` if DC index is unknown.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> { pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
@@ -232,7 +243,6 @@ impl UpstreamManager {
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
// Store per-DC latency
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) { if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms); u.dc_latency[di].update(rtt_ms);
} }
@@ -336,9 +346,10 @@ impl UpstreamManager {
} }
} }
// ============= Startup Ping ============= // ============= Startup Ping (test both IPv6 and IPv4) =============
/// Ping all Telegram DCs through all upstreams. /// Ping all Telegram DCs through all upstreams.
/// Tests BOTH IPv6 and IPv4, returns separate results for each.
pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> { pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = { let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await; let guard = self.upstreams.read().await;
@@ -347,8 +358,6 @@ impl UpstreamManager {
.collect() .collect()
}; };
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new(); let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams { for (upstream_idx, upstream_config) in &upstreams {
@@ -360,50 +369,115 @@ impl UpstreamManager {
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address), UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
}; };
let mut dc_results = Vec::new(); let mut v6_results = Vec::new();
let mut v4_results = Vec::new();
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() { // === Ping IPv6 first ===
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT); for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(5), Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(upstream_config, dc_addr) self.ping_single_dc(&upstream_config, addr_v6)
).await; ).await;
let result = match ping_result { let ping_result = match result {
Ok(Ok(rtt_ms)) => { Ok(Ok(rtt_ms)) => {
// Store per-DC latency
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) { if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms); u.dc_latency[dc_zero_idx].update(rtt_ms);
} }
DcPingResult { DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: Some(rtt_ms), rtt_ms: Some(rtt_ms),
error: None, error: None,
} }
} }
Ok(Err(e)) => DcPingResult { Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: None, rtt_ms: None,
error: Some(e.to_string()), error: Some(e.to_string()),
}, },
Err(_) => DcPingResult { Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1, dc_idx: dc_zero_idx + 1,
dc_addr, dc_addr: addr_v6,
rtt_ms: None, rtt_ms: None,
error: Some("timeout (5s)".to_string()), error: Some("timeout".to_string()),
}, },
}; };
v6_results.push(ping_result);
}
dc_results.push(result); // === Then ping IPv4 ===
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v4)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v4_results.push(ping_result);
}
// Check if both IP versions have at least one working DC
let v6_has_working = v6_results.iter().any(|r| r.rtt_ms.is_some());
let v4_has_working = v4_results.iter().any(|r| r.rtt_ms.is_some());
let both_available = v6_has_working && v4_has_working;
// Update IP preference for each DC
{
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
for dc_zero_idx in 0..NUM_DCS {
let v6_ok = v6_results[dc_zero_idx].rtt_ms.is_some();
let v4_ok = v4_results[dc_zero_idx].rtt_ms.is_some();
u.dc_ip_pref[dc_zero_idx] = match (v6_ok, v4_ok) {
(true, true) => IpPreference::BothWork,
(true, false) => IpPreference::PreferV6,
(false, true) => IpPreference::PreferV4,
(false, false) => IpPreference::Unavailable,
};
}
}
} }
all_results.push(StartupPingResult { all_results.push(StartupPingResult {
results: dc_results, v6_results,
v4_results,
upstream_name, upstream_name,
both_available,
}); });
} }
@@ -419,19 +493,30 @@ impl UpstreamManager {
// ============= Health Checks ============= // ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval. /// Background health check: rotates through DCs, 30s interval.
/// Uses preferred IP version based on config.
pub async fn run_health_checks(&self, prefer_ipv6: bool) { pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut dc_rotation = 0usize; let mut dc_rotation = 0usize;
loop { loop {
tokio::time::sleep(Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % datacenters.len(); let dc_zero_idx = dc_rotation % NUM_DCS;
dc_rotation += 1; dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT); let dc_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
} else {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
};
let fallback_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
} else {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
};
let count = self.upstreams.read().await.len(); let count = self.upstreams.read().await.len();
for i in 0..count { for i in 0..count {
let config = { let config = {
let guard = self.upstreams.read().await; let guard = self.upstreams.read().await;
@@ -441,48 +526,102 @@ impl UpstreamManager {
let start = Instant::now(); let start = Instant::now();
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(10), Duration::from_secs(10),
self.connect_via_upstream(&config, check_target) self.connect_via_upstream(&config, dc_addr)
).await; ).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result { match result {
Ok(Ok(_stream)) => { Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
u.dc_latency[dc_zero_idx].update(rtt_ms); u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy { if !u.healthy {
info!( info!(
rtt = format!("{:.0}ms", rtt_ms), rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1, dc = dc_zero_idx + 1,
"Upstream recovered" "Upstream recovered"
); );
} }
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
u.last_check = std::time::Instant::now();
} }
Ok(Err(e)) => { Ok(Err(_)) | Err(_) => {
u.fails += 1; // Try fallback
debug!(dc = dc_zero_idx + 1, fails = u.fails, debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
"Health check failed: {}", e);
if u.fails > 3 { let start2 = Instant::now();
u.healthy = false; let result2 = tokio::time::timeout(
warn!("Upstream unhealthy (fails)"); Duration::from_secs(10),
} self.connect_via_upstream(&config, fallback_addr)
} ).await;
Err(_) => {
u.fails += 1; let mut guard = self.upstreams.write().await;
debug!(dc = dc_zero_idx + 1, fails = u.fails, let u = &mut guard[i];
"Health check timeout");
if u.fails > 3 { match result2 {
u.healthy = false; Ok(Ok(_stream)) => {
warn!("Upstream unhealthy (timeout)"); let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
info!(
rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered (fallback)"
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed (both): {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
}
Err(_) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout (both)");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
}
} }
u.last_check = std::time::Instant::now();
} }
} }
u.last_check = std::time::Instant::now();
} }
} }
} }
/// Get the preferred IP for a DC (for use by other components)
pub async fn get_dc_ip_preference(&self, dc_idx: i16) -> Option<IpPreference> {
let guard = self.upstreams.read().await;
if guard.is_empty() {
return None;
}
UpstreamState::dc_array_idx(dc_idx)
.map(|idx| guard[0].dc_ip_pref[idx])
}
/// Get preferred DC address based on config preference
pub async fn get_dc_addr(&self, dc_idx: i16, prefer_ipv6: bool) -> Option<SocketAddr> {
let arr_idx = UpstreamState::dc_array_idx(dc_idx)?;
let ip = if prefer_ipv6 {
TG_DATACENTERS_V6[arr_idx]
} else {
TG_DATACENTERS_V4[arr_idx]
};
Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
}
} }