From 70859aa5cf1e650cf902e49cdc132a3f83f15d02 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sat, 14 Feb 2026 01:36:14 +0300 Subject: [PATCH] Middle Proxy is so real --- src/config/mod.rs | 170 +++-- src/main.rs | 266 ++++--- src/proxy/client.rs | 931 +++++++++---------------- src/proxy/direct_relay.rs | 163 +++++ src/proxy/middle_relay.rs | 247 +++++++ src/transport/middle_proxy/codec.rs | 178 +++++ src/transport/middle_proxy/health.rs | 38 + src/transport/middle_proxy/mod.rs | 24 + src/transport/middle_proxy/pool.rs | 431 ++++++++++++ src/transport/middle_proxy/reader.rs | 141 ++++ src/transport/middle_proxy/registry.rs | 40 ++ src/transport/middle_proxy/secret.rs | 76 ++ src/transport/middle_proxy/wire.rs | 106 +++ src/transport/mod.rs | 2 +- 14 files changed, 2028 insertions(+), 785 deletions(-) create mode 100644 src/proxy/direct_relay.rs create mode 100644 src/proxy/middle_relay.rs create mode 100644 src/transport/middle_proxy/codec.rs create mode 100644 src/transport/middle_proxy/health.rs create mode 100644 src/transport/middle_proxy/mod.rs create mode 100644 src/transport/middle_proxy/pool.rs create mode 100644 src/transport/middle_proxy/reader.rs create mode 100644 src/transport/middle_proxy/registry.rs create mode 100644 src/transport/middle_proxy/secret.rs create mode 100644 src/transport/middle_proxy/wire.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 7f74489..dbf8afa 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,32 +1,55 @@ //! Configuration +use crate::error::{ProxyError, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::path::Path; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use crate::error::{ProxyError, Result}; // ============= Helper Defaults ============= -fn default_true() -> bool { true } -fn default_port() -> u16 { 443 } -fn default_tls_domain() -> String { "www.google.com".to_string() } -fn default_mask_port() -> u16 { 443 } -fn default_replay_check_len() -> usize { 65536 } -fn default_replay_window_secs() -> u64 { 1800 } -fn default_handshake_timeout() -> u64 { 15 } -fn default_connect_timeout() -> u64 { 10 } -fn default_keepalive() -> u64 { 60 } -fn default_ack_timeout() -> u64 { 300 } -fn default_listen_addr() -> String { "0.0.0.0".to_string() } -fn default_fake_cert_len() -> usize { 2048 } -fn default_weight() -> u16 { 1 } +fn default_true() -> bool { + true +} +fn default_port() -> u16 { + 443 +} +fn default_tls_domain() -> String { + "www.google.com".to_string() +} +fn default_mask_port() -> u16 { + 443 +} +fn default_replay_check_len() -> usize { + 65536 +} +fn default_replay_window_secs() -> u64 { + 1800 +} +fn default_handshake_timeout() -> u64 { + 15 +} +fn default_connect_timeout() -> u64 { + 10 +} +fn default_keepalive() -> u64 { + 60 +} +fn default_ack_timeout() -> u64 { + 300 +} +fn default_listen_addr() -> String { + "0.0.0.0".to_string() +} +fn default_fake_cert_len() -> usize { + 2048 +} +fn default_weight() -> u16 { + 1 +} fn default_metrics_whitelist() -> Vec { - vec![ - "127.0.0.1".parse().unwrap(), - "::1".parse().unwrap(), - ] + vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()] } // ============= Log Level ============= @@ -96,7 +119,11 @@ pub struct ProxyModes { impl Default for ProxyModes { fn default() -> Self { - Self { classic: true, secure: true, tls: true } + Self { + classic: true, + secure: true, + tls: true, + } } } @@ -104,13 +131,13 @@ impl Default for ProxyModes { pub struct GeneralConfig { #[serde(default)] pub modes: ProxyModes, - + #[serde(default)] pub prefer_ipv6: bool, - + #[serde(default = "default_true")] pub fast_mode: bool, - + #[serde(default)] pub use_middle_proxy: bool, @@ -121,7 +148,12 @@ pub struct GeneralConfig { /// Infrastructure secret from https://core.telegram.org/getProxySecret #[serde(default)] pub proxy_secret_path: Option, - + + /// Public IP override for middle-proxy NAT environments. + /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". + #[serde(default)] + pub middle_proxy_nat_ip: Option, + #[serde(default)] pub log_level: LogLevel, } @@ -135,6 +167,7 @@ impl Default for GeneralConfig { use_middle_proxy: false, ad_tag: None, proxy_secret_path: None, + middle_proxy_nat_ip: None, log_level: LogLevel::Normal, } } @@ -147,16 +180,16 @@ pub struct ServerConfig { #[serde(default = "default_listen_addr")] pub listen_addr_ipv4: String, - + #[serde(default)] pub listen_addr_ipv6: Option, - + #[serde(default)] pub listen_unix_sock: Option, - + #[serde(default)] pub metrics_port: Option, - + #[serde(default = "default_metrics_whitelist")] pub metrics_whitelist: Vec, @@ -182,13 +215,13 @@ impl Default for ServerConfig { pub struct TimeoutsConfig { #[serde(default = "default_handshake_timeout")] pub client_handshake: u64, - + #[serde(default = "default_connect_timeout")] pub tg_connect: u64, - + #[serde(default = "default_keepalive")] pub client_keepalive: u64, - + #[serde(default = "default_ack_timeout")] pub client_ack: u64, } @@ -208,13 +241,13 @@ impl Default for TimeoutsConfig { pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] pub tls_domain: String, - + #[serde(default = "default_true")] pub mask: bool, - + #[serde(default)] pub mask_host: Option, - + #[serde(default = "default_mask_port")] pub mask_port: u16, @@ -245,19 +278,19 @@ pub struct AccessConfig { #[serde(default)] pub user_max_tcp_conns: HashMap, - + #[serde(default)] pub user_expirations: HashMap>, - + #[serde(default)] pub user_data_quota: HashMap, #[serde(default = "default_replay_check_len")] pub replay_check_len: usize, - + #[serde(default = "default_replay_window_secs")] pub replay_window_secs: u64, - + #[serde(default)] pub ignore_time_skew: bool, } @@ -265,7 +298,10 @@ pub struct AccessConfig { impl Default for AccessConfig { fn default() -> Self { let mut users = HashMap::new(); - users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); + users.insert( + "default".to_string(), + "00000000000000000000000000000000".to_string(), + ); Self { users, user_max_tcp_conns: HashMap::new(), @@ -365,12 +401,12 @@ pub struct ProxyConfig { impl ProxyConfig { pub fn load>(path: P) -> Result { - let content = std::fs::read_to_string(path) - .map_err(|e| ProxyError::Config(e.to_string()))?; - - let mut config: ProxyConfig = toml::from_str(&content) - .map_err(|e| ProxyError::Config(e.to_string()))?; - + let content = + std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + + let mut config: ProxyConfig = + toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; + // Validate secrets for (user, secret) in &config.access.users { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { @@ -380,33 +416,34 @@ impl ProxyConfig { }); } } - + // Validate tls_domain if config.censorship.tls_domain.is_empty() { return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); } - + // Validate mask_unix_sock if let Some(ref sock_path) = config.censorship.mask_unix_sock { if sock_path.is_empty() { return Err(ProxyError::Config( - "mask_unix_sock cannot be empty".to_string() + "mask_unix_sock cannot be empty".to_string(), )); } #[cfg(unix)] if sock_path.len() > 107 { - return Err(ProxyError::Config( - format!("mask_unix_sock path too long: {} bytes (max 107)", sock_path.len()) - )); + return Err(ProxyError::Config(format!( + "mask_unix_sock path too long: {} bytes (max 107)", + sock_path.len() + ))); } #[cfg(not(unix))] return Err(ProxyError::Config( - "mask_unix_sock is only supported on Unix platforms".to_string() + "mask_unix_sock is only supported on Unix platforms".to_string(), )); if config.censorship.mask_host.is_some() { return Err(ProxyError::Config( - "mask_unix_sock and mask_host are mutually exclusive".to_string() + "mask_unix_sock and mask_host are mutually exclusive".to_string(), )); } } @@ -415,11 +452,11 @@ impl ProxyConfig { if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() { config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } - + // Random fake_cert_len use rand::Rng; config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); - + // Migration: Populate listeners if empty if config.server.listeners.is_empty() { if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::() { @@ -429,7 +466,7 @@ impl ProxyConfig { }); } if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { - if let Ok(ipv6) = ipv6_str.parse::() { + if let Ok(ipv6) = ipv6_str.parse::() { config.server.listeners.push(ListenerConfig { ip: ipv6, announce_ip: None, @@ -440,31 +477,32 @@ impl ProxyConfig { // Migration: Populate upstreams if empty (Default Direct) if config.upstreams.is_empty() { - config.upstreams.push(UpstreamConfig { + config.upstreams.push(UpstreamConfig { upstream_type: UpstreamType::Direct { interface: None }, weight: 1, enabled: true, }); } - + Ok(config) } - + pub fn validate(&self) -> Result<()> { if self.access.users.is_empty() { return Err(ProxyError::Config("No users configured".to_string())); } - + if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls { return Err(ProxyError::Config("No modes enabled".to_string())); } - + if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { - return Err(ProxyError::Config( - format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) - )); + return Err(ProxyError::Config(format!( + "Invalid tls_domain: '{}'. Must be a valid domain name", + self.censorship.tls_domain + ))); } Ok(()) } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 5b8c491..e03f600 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,8 +6,8 @@ use std::time::Duration; use tokio::net::TcpListener; use tokio::signal; use tokio::sync::Semaphore; -use tracing::{info, error, warn, debug}; -use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*}; +use tracing::{debug, error, info, warn}; +use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload}; mod cli; mod config; @@ -20,14 +20,14 @@ mod stream; mod transport; mod util; -use crate::config::{ProxyConfig, LogLevel}; -use crate::proxy::ClientHandler; -use crate::stats::{Stats, ReplayChecker}; +use crate::config::{LogLevel, ProxyConfig}; use crate::crypto::SecureRandom; -use crate::transport::{create_listener, ListenOptions, UpstreamManager}; -use crate::transport::middle_proxy::MePool; -use crate::util::ip::detect_ip; +use crate::proxy::ClientHandler; +use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; +use crate::transport::middle_proxy::MePool; +use crate::transport::{ListenOptions, UpstreamManager, create_listener}; +use crate::util::ip::detect_ip; fn parse_cli() -> (String, bool, Option) { let mut config_path = "config.toml".to_string(); @@ -48,10 +48,14 @@ fn parse_cli() -> (String, bool, Option) { let mut i = 0; while i < args.len() { match args[i].as_str() { - "--silent" | "-s" => { silent = true; } + "--silent" | "-s" => { + silent = true; + } "--log-level" => { i += 1; - if i < args.len() { log_level = Some(args[i].clone()); } + if i < args.len() { + log_level = Some(args[i].clone()); + } } s if s.starts_with("--log-level=") => { log_level = Some(s.trim_start_matches("--log-level=").to_string()); @@ -65,17 +69,27 @@ fn parse_cli() -> (String, bool, Option) { eprintln!(" --help, -h Show this help"); eprintln!(); eprintln!("Setup (fire-and-forget):"); - eprintln!(" --init Generate config, install systemd service, start"); + eprintln!( + " --init Generate config, install systemd service, start" + ); eprintln!(" --port Listen port (default: 443)"); - eprintln!(" --domain TLS domain for masking (default: www.google.com)"); - eprintln!(" --secret 32-char hex secret (auto-generated if omitted)"); + eprintln!( + " --domain TLS domain for masking (default: www.google.com)" + ); + eprintln!( + " --secret 32-char hex secret (auto-generated if omitted)" + ); eprintln!(" --user Username (default: user)"); eprintln!(" --config-dir Config directory (default: /etc/telemt)"); eprintln!(" --no-start Don't start the service after install"); std::process::exit(0); } - s if !s.starts_with('-') => { config_path = s.to_string(); } - other => { eprintln!("Unknown option: {}", other); } + s if !s.starts_with('-') => { + config_path = s.to_string(); + } + other => { + eprintln!("Unknown option: {}", other); + } } i += 1; } @@ -124,21 +138,30 @@ async fn main() -> std::result::Result<(), Box> { info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION")); info!("Log level: {}", effective_log_level); - info!("Modes: classic={} secure={} tls={}", - config.general.modes.classic, - config.general.modes.secure, - config.general.modes.tls); + info!( + "Modes: classic={} secure={} tls={}", + config.general.modes.classic, config.general.modes.secure, config.general.modes.tls + ); info!("TLS domain: {}", config.censorship.tls_domain); if let Some(ref sock) = config.censorship.mask_unix_sock { info!("Mask: {} -> unix:{}", config.censorship.mask, sock); if !std::path::Path::new(sock).exists() { - warn!("Unix socket '{}' does not exist yet. Masking will fail until it appears.", sock); + warn!( + "Unix socket '{}' does not exist yet. Masking will fail until it appears.", + sock + ); } } else { - info!("Mask: {} -> {}:{}", + info!( + "Mask: {} -> {}:{}", config.censorship.mask, - config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain), - config.censorship.mask_port); + config + .censorship + .mask_host + .as_deref() + .unwrap_or(&config.censorship.tls_domain), + config.censorship.mask_port + ); } if config.censorship.tls_domain == "www.google.com" { @@ -166,69 +189,78 @@ async fn main() -> std::result::Result<(), Box> { // Middle Proxy initialization (if enabled) // ===================================================================== let me_pool: Option> = if use_middle_proxy { - info!("=== Middle Proxy Mode ==="); + info!("=== Middle Proxy Mode ==="); - // ad_tag (proxy_tag) for advertising - let proxy_tag = config.general.ad_tag.as_ref().map(|tag| { - hex::decode(tag).unwrap_or_else(|_| { - warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty"); - Vec::new() - }) - }); + // ad_tag (proxy_tag) for advertising + let proxy_tag = config.general.ad_tag.as_ref().map(|tag| { + hex::decode(tag).unwrap_or_else(|_| { + warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty"); + Vec::new() + }) + }); - // ============================================================= - // CRITICAL: Download Telegram proxy-secret (NOT user secret!) - // - // C MTProxy uses TWO separate secrets: - // -S flag = 16-byte user secret for client obfuscation - // --aes-pwd = 32-512 byte binary file for ME RPC auth - // - // proxy-secret is from: https://core.telegram.org/getProxySecret - // ============================================================= - let proxy_secret_path = config.general.proxy_secret_path.as_deref(); - match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await { - Ok(proxy_secret) => { - info!( - secret_len = proxy_secret.len(), - key_sig = format_args!("0x{:08x}", - if proxy_secret.len() >= 4 { - u32::from_le_bytes([proxy_secret[0], proxy_secret[1], - proxy_secret[2], proxy_secret[3]]) - } else { 0 }), - "Proxy-secret loaded" - ); - - let pool = MePool::new(proxy_tag, proxy_secret); - - match pool.init(2, &rng).await { - Ok(()) => { - info!("Middle-End pool initialized successfully"); - - // Phase 4: Start health monitor - let pool_clone = pool.clone(); - let rng_clone = rng.clone(); - tokio::spawn(async move { - crate::transport::middle_proxy::me_health_monitor( - pool_clone, rng_clone, 2, - ).await; - }); - - Some(pool) - } - Err(e) => { - error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode."); - None + // ============================================================= + // CRITICAL: Download Telegram proxy-secret (NOT user secret!) + // + // C MTProxy uses TWO separate secrets: + // -S flag = 16-byte user secret for client obfuscation + // --aes-pwd = 32-512 byte binary file for ME RPC auth + // + // proxy-secret is from: https://core.telegram.org/getProxySecret + // ============================================================= + let proxy_secret_path = config.general.proxy_secret_path.as_deref(); + match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await { + Ok(proxy_secret) => { + info!( + secret_len = proxy_secret.len(), + key_sig = format_args!( + "0x{:08x}", + if proxy_secret.len() >= 4 { + u32::from_le_bytes([ + proxy_secret[0], + proxy_secret[1], + proxy_secret[2], + proxy_secret[3], + ]) + } else { + 0 } + ), + "Proxy-secret loaded" + ); + + let pool = MePool::new(proxy_tag, proxy_secret, config.general.middle_proxy_nat_ip); + + match pool.init(2, &rng).await { + Ok(()) => { + info!("Middle-End pool initialized successfully"); + + // Phase 4: Start health monitor + let pool_clone = pool.clone(); + let rng_clone = rng.clone(); + tokio::spawn(async move { + crate::transport::middle_proxy::me_health_monitor( + pool_clone, rng_clone, 2, + ) + .await; + }); + + Some(pool) + } + Err(e) => { + error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode."); + None } } - Err(e) => { - error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode."); - None - } } - } else { - None - }; + Err(e) => { + error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode."); + None + } + } + } else { + None + }; if me_pool.is_some() { info!("Transport: Middle Proxy (supports all DCs including CDN)"); @@ -251,8 +283,14 @@ async fn main() -> std::result::Result<(), Box> { info!(" IPv4 in use and IPv6 is fallback"); } } else { - let v6_works = upstream_result.v6_results.iter().any(|r| r.rtt_ms.is_some()); - let v4_works = upstream_result.v4_results.iter().any(|r| r.rtt_ms.is_some()); + let v6_works = upstream_result + .v6_results + .iter() + .any(|r| r.rtt_ms.is_some()); + let v4_works = upstream_result + .v4_results + .iter() + .any(|r| r.rtt_ms.is_some()); if v6_works && !v4_works { info!(" IPv6 only (IPv4 unavailable)"); } else if v4_works && !v6_works { @@ -290,11 +328,17 @@ async fn main() -> std::result::Result<(), Box> { Some(rtt) => { // Align: IPv4 addresses are shorter, use more tabs // 149.154.175.50:443 = ~18 chars - info!(" DC{} [IPv4] {}:\t\t\t\t{:.0} ms", dc.dc_idx, addr_str, rtt); + info!( + " DC{} [IPv4] {}:\t\t\t\t{:.0} ms", + dc.dc_idx, addr_str, rtt + ); } None => { let err = dc.error.as_deref().unwrap_or("fail"); - info!(" DC{} [IPv4] {}:\t\t\t\tFAIL ({})", dc.dc_idx, addr_str, err); + info!( + " DC{} [IPv4] {}:\t\t\t\tFAIL ({})", + dc.dc_idx, addr_str, err + ); } } } @@ -305,13 +349,20 @@ async fn main() -> std::result::Result<(), Box> { // Background tasks let um_clone = upstream_manager.clone(); - tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; }); + tokio::spawn(async move { + um_clone.run_health_checks(prefer_ipv6).await; + }); let rc_clone = replay_checker.clone(); - tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; }); + tokio::spawn(async move { + rc_clone.run_periodic_cleanup().await; + }); let detected_ip = detect_ip().await; - debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6); + debug!( + "Detected IPs: v4={:?} v6={:?}", + detected_ip.ipv4, detected_ip.ipv6 + ); let mut listeners = Vec::new(); @@ -345,17 +396,23 @@ async fn main() -> std::result::Result<(), Box> { if let Some(secret) = config.access.users.get(user_name) { info!("User: {}", user_name); if config.general.modes.classic { - info!(" Classic: tg://proxy?server={}&port={}&secret={}", - public_ip, config.server.port, secret); + info!( + " Classic: tg://proxy?server={}&port={}&secret={}", + public_ip, config.server.port, secret + ); } if config.general.modes.secure { - info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", - public_ip, config.server.port, secret); + info!( + " DD: tg://proxy?server={}&port={}&secret=dd{}", + public_ip, config.server.port, secret + ); } if config.general.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain); - info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - public_ip, config.server.port, secret, domain_hex); + info!( + " EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", + public_ip, config.server.port, secret, domain_hex + ); } } else { warn!("User '{}' in show_link not found", user_name); @@ -365,7 +422,7 @@ async fn main() -> std::result::Result<(), Box> { } listeners.push(listener); - }, + } Err(e) => { error!("Failed to bind to {}: {}", addr, e); } @@ -383,7 +440,9 @@ async fn main() -> std::result::Result<(), Box> { } else { EnvFilter::new(effective_log_level.to_filter_str()) }; - filter_handle.reload(runtime_filter).expect("Failed to switch log filter"); + filter_handle + .reload(runtime_filter) + .expect("Failed to switch log filter"); for listener in listeners { let config = config.clone(); @@ -408,10 +467,19 @@ async fn main() -> std::result::Result<(), Box> { tokio::spawn(async move { if let Err(e) = ClientHandler::new( - stream, peer_addr, config, stats, - upstream_manager, replay_checker, buffer_pool, rng, + stream, + peer_addr, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, me_pool, - ).run().await { + ) + .run() + .await + { debug!(peer = %peer_addr, error = %e, "Connection error"); } }); @@ -431,4 +499,4 @@ async fn main() -> std::result::Result<(), Box> { } Ok(()) -} \ No newline at end of file +} diff --git a/src/proxy/client.rs b/src/proxy/client.rs index caceb33..726d238 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,661 +1,354 @@ //! Client Handler - - use std::net::SocketAddr; - use std::sync::Arc; - use std::time::Duration; - use tokio::net::TcpStream; - use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; - use tokio::time::timeout; - use tracing::{debug, info, warn, error, trace}; - - use crate::config::ProxyConfig; - use crate::error::{ProxyError, Result, HandshakeResult}; - use crate::protocol::constants::*; - use crate::protocol::tls; - use crate::stats::{Stats, ReplayChecker}; - use crate::transport::{configure_client_socket, UpstreamManager}; - use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; - use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; - use crate::crypto::{AesCtr, SecureRandom}; - - use crate::proxy::handshake::{ - handle_tls_handshake, handle_mtproto_handshake, - HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce_with_ciphers, - }; - use crate::proxy::relay::relay_bidirectional; - use crate::proxy::masking::handle_bad_client; - - pub struct ClientHandler; - - pub struct RunningClientHandler { + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tracing::{debug, warn}; + +use crate::config::ProxyConfig; +use crate::crypto::SecureRandom; +use crate::error::{HandshakeResult, ProxyError, Result}; +use crate::protocol::constants::*; +use crate::protocol::tls; +use crate::stats::{ReplayChecker, Stats}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::middle_proxy::MePool; +use crate::transport::{UpstreamManager, configure_client_socket}; + +use crate::proxy::direct_relay::handle_via_direct; +use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; +use crate::proxy::masking::handle_bad_client; +use crate::proxy::middle_relay::handle_via_middle_proxy; + +pub struct ClientHandler; + +pub struct RunningClientHandler { + stream: TcpStream, + peer: SocketAddr, + config: Arc, + stats: Arc, + replay_checker: Arc, + upstream_manager: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, +} + +impl ClientHandler { + pub fn new( stream: TcpStream, peer: SocketAddr, config: Arc, stats: Arc, - replay_checker: Arc, upstream_manager: Arc, + replay_checker: Arc, buffer_pool: Arc, rng: Arc, me_pool: Option>, + ) -> RunningClientHandler { + RunningClientHandler { + stream, + peer, + config, + stats, + replay_checker, + upstream_manager, + buffer_pool, + rng, + me_pool, + } } - - impl ClientHandler { - pub fn new( - stream: TcpStream, - peer: SocketAddr, - config: Arc, - stats: Arc, - upstream_manager: Arc, - replay_checker: Arc, - buffer_pool: Arc, - rng: Arc, - me_pool: Option>, - ) -> RunningClientHandler { - RunningClientHandler { - stream, peer, config, stats, replay_checker, - upstream_manager, buffer_pool, rng, me_pool, +} + +impl RunningClientHandler { + pub async fn run(mut self) -> Result<()> { + self.stats.increment_connects_all(); + + let peer = self.peer; + debug!(peer = %peer, "New connection"); + + if let Err(e) = configure_client_socket( + &self.stream, + self.config.timeouts.client_keepalive, + self.config.timeouts.client_ack, + ) { + debug!(peer = %peer, error = %e, "Failed to configure client socket"); + } + + let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); + let stats = self.stats.clone(); + + let result = timeout(handshake_timeout, self.do_handshake()).await; + + match result { + Ok(Ok(())) => { + debug!(peer = %peer, "Connection handled successfully"); + Ok(()) + } + Ok(Err(e)) => { + debug!(peer = %peer, error = %e, "Handshake failed"); + Err(e) + } + Err(_) => { + stats.increment_handshake_timeouts(); + debug!(peer = %peer, "Handshake timeout"); + Err(ProxyError::TgHandshakeTimeout) } } } - - impl RunningClientHandler { - pub async fn run(mut self) -> Result<()> { - self.stats.increment_connects_all(); - - let peer = self.peer; - debug!(peer = %peer, "New connection"); - - if let Err(e) = configure_client_socket( - &self.stream, - self.config.timeouts.client_keepalive, - self.config.timeouts.client_ack, - ) { - debug!(peer = %peer, error = %e, "Failed to configure client socket"); - } - - let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); - let stats = self.stats.clone(); - - let result = timeout(handshake_timeout, self.do_handshake()).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } - Ok(Err(e)) => { - debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) - } - Err(_) => { - stats.increment_handshake_timeouts(); - debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) - } - } + + async fn do_handshake(mut self) -> Result<()> { + let mut first_bytes = [0u8; 5]; + self.stream.read_exact(&mut first_bytes).await?; + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + let peer = self.peer; + + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + if is_tls { + self.handle_tls_client(first_bytes).await + } else { + self.handle_direct_client(first_bytes).await } - - async fn do_handshake(mut self) -> Result<()> { - let mut first_bytes = [0u8; 5]; - self.stream.read_exact(&mut first_bytes).await?; - - let is_tls = tls::is_tls_handshake(&first_bytes[..3]); - let peer = self.peer; - - debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); - - if is_tls { - self.handle_tls_client(first_bytes).await - } else { - self.handle_direct_client(first_bytes).await - } + } + + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + let peer = self.peer; + + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); + + if tls_len < 512 { + debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + self.stats.increment_connects_bad(); + let (reader, writer) = self.stream.into_split(); + handle_bad_client(reader, writer, &first_bytes, &self.config).await; + return Ok(()); } - - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { - let peer = self.peer; - - let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - - debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - - if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); - self.stats.increment_connects_bad(); - let (reader, writer) = self.stream.into_split(); - handle_bad_client(reader, writer, &first_bytes, &self.config).await; + + let mut handshake = vec![0u8; 5 + tls_len]; + handshake[..5].copy_from_slice(&first_bytes); + self.stream.read_exact(&mut handshake[5..]).await?; + + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); + + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; + let (read_half, write_half) = self.stream.into_split(); + + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + &self.rng, + ) + .await + { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; return Ok(()); } - - let mut handshake = vec![0u8; 5 + tls_len]; - handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; - - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); - - let (read_half, write_half) = self.stream.into_split(); - - let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, &self.rng, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; - - debug!(peer = %peer, "Reading MTProto handshake through TLS"); - let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; - let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() - .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; - - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &mtproto_handshake, tls_reader, tls_writer, peer, - &config, &replay_checker, true, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader: _, writer: _ } => { - stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; - - Self::handle_authenticated_static( - crypto_reader, crypto_writer, success, - self.upstream_manager, self.stats, self.config, - buffer_pool, self.rng, self.me_pool, - ).await - } - - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { - let peer = self.peer; - - if !self.config.general.modes.classic && !self.config.general.modes.secure { - debug!(peer = %peer, "Non-TLS modes disabled"); - self.stats.increment_connects_bad(); - let (reader, writer) = self.stream.into_split(); - handle_bad_client(reader, writer, &first_bytes, &self.config).await; + HandshakeResult::Error(e) => return Err(e), + }; + + debug!(peer = %peer, "Reading MTProto handshake through TLS"); + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..] + .try_into() + .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &mtproto_handshake, + tls_reader, + tls_writer, + peer, + &config, + &replay_checker, + true, + ) + .await + { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { + reader: _, + writer: _, + } => { + stats.increment_connects_bad(); + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); return Ok(()); } - - let mut handshake = [0u8; HANDSHAKE_LEN]; - handshake[..5].copy_from_slice(&first_bytes); - self.stream.read_exact(&mut handshake[5..]).await?; - - let config = self.config.clone(); - let replay_checker = self.replay_checker.clone(); - let stats = self.stats.clone(); - let buffer_pool = self.buffer_pool.clone(); - - let (read_half, write_half) = self.stream.into_split(); - - let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( - &handshake, read_half, write_half, peer, - &config, &replay_checker, false, - ).await { - HandshakeResult::Success(result) => result, - HandshakeResult::BadClient { reader, writer } => { - stats.increment_connects_bad(); - handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); - } - HandshakeResult::Error(e) => return Err(e), - }; - - Self::handle_authenticated_static( - crypto_reader, crypto_writer, success, - self.upstream_manager, self.stats, self.config, - buffer_pool, self.rng, self.me_pool, - ).await + HandshakeResult::Error(e) => return Err(e), + }; + + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + ) + .await + } + + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + let peer = self.peer; + + if !self.config.general.modes.classic && !self.config.general.modes.secure { + debug!(peer = %peer, "Non-TLS modes disabled"); + self.stats.increment_connects_bad(); + let (reader, writer) = self.stream.into_split(); + handle_bad_client(reader, writer, &first_bytes, &self.config).await; + return Ok(()); } - - /// Main dispatch after successful handshake. - /// Two modes: - /// - Direct: TCP relay to TG DC (existing behavior) - /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) - async fn handle_authenticated_static( - client_reader: CryptoReader, - client_writer: CryptoWriter, - success: HandshakeSuccess, - upstream_manager: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - rng: Arc, - me_pool: Option>, - ) -> Result<()> - where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, + + let mut handshake = [0u8; HANDSHAKE_LEN]; + handshake[..5].copy_from_slice(&first_bytes); + self.stream.read_exact(&mut handshake[5..]).await?; + + let config = self.config.clone(); + let replay_checker = self.replay_checker.clone(); + let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); + + let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; + let (read_half, write_half) = self.stream.into_split(); + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &handshake, + read_half, + write_half, + peer, + &config, + &replay_checker, + false, + ) + .await { - let user = &success.user; - - if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { - warn!(user = %user, error = %e, "User limit exceeded"); - return Err(e); + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(()); } - - // Decide: middle proxy or direct - if config.general.use_middle_proxy { - if let Some(ref pool) = me_pool { - return Self::handle_via_middle_proxy( - client_reader, client_writer, success, - pool.clone(), stats, config, buffer_pool, - ).await; - } - warn!("use_middle_proxy=true but MePool not initialized, falling back to direct"); - } - - // Direct mode (original behavior) - Self::handle_via_direct( - client_reader, client_writer, success, - upstream_manager, stats, config, buffer_pool, rng, - ).await - } - - // ===================================================================== - // Direct mode — TCP relay to Telegram DC - // ===================================================================== - - async fn handle_via_direct( - client_reader: CryptoReader, - client_writer: CryptoWriter, - success: HandshakeSuccess, - upstream_manager: Arc, - stats: Arc, - config: Arc, - buffer_pool: Arc, - rng: Arc, - ) -> Result<()> - where - R: AsyncRead + Unpin + Send + 'static, - W: AsyncWrite + Unpin + Send + 'static, - { - let user = &success.user; - let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; - - info!( - user = %user, - peer = %success.peer, - dc = success.dc_idx, - dc_addr = %dc_addr, - proto = ?success.proto_tag, - mode = "direct", - "Connecting to Telegram DC" - ); - - let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?; - - debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); - - let (tg_reader, tg_writer) = Self::do_tg_handshake_static( - tg_stream, &success, &config, rng.as_ref(), - ).await?; - - debug!(peer = %success.peer, "TG handshake complete, starting relay"); - - stats.increment_user_connects(user); - stats.increment_user_curr_connects(user); - - let relay_result = relay_bidirectional( - client_reader, client_writer, - tg_reader, tg_writer, - user, Arc::clone(&stats), buffer_pool, - ).await; - - stats.decrement_user_curr_connects(user); - - match &relay_result { - Ok(()) => debug!(user = %user, "Direct relay completed"), - Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), - } - - relay_result - } - - // ===================================================================== - // Middle Proxy mode — RPC multiplex through ME pool - // ===================================================================== - - /// Middle Proxy RPC relay - /// - /// Architecture (matches C MTProxy): - /// ```text - /// Client ←AES-CTR→ [telemt] ←RPC/AES-CBC→ ME ←internal→ DC (any, incl CDN 203) - /// ``` - /// - /// Key difference from direct mode: - /// - No per-client TCP to DC; all clients share ME pool connections - /// - ME internally routes to correct DC based on client's encrypted auth_key_id - /// - CDN DCs (203+) work because ME knows their internal addresses - /// - We pass raw client MTProto bytes in RPC_PROXY_REQ envelope - /// - ME returns responses in RPC_PROXY_ANS envelope - - async fn handle_via_middle_proxy( - crypto_reader: CryptoReader, - crypto_writer: CryptoWriter, + HandshakeResult::Error(e) => return Err(e), + }; + + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + ) + .await + } + + /// Main dispatch after successful handshake. + /// Two modes: + /// - Direct: TCP relay to TG DC (existing behavior) + /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs) + async fn handle_authenticated_static( + client_reader: CryptoReader, + client_writer: CryptoWriter, success: HandshakeSuccess, - me_pool: Arc, + upstream_manager: Arc, stats: Arc, config: Arc, - _buffer_pool: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, + local_addr: SocketAddr, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static, { - let mut client_reader = crypto_reader; - let mut client_writer = crypto_writer; - - let user = success.user.clone(); - let peer = success.peer; + let user = &success.user; - info!( - user = %user, - peer = %peer, - dc = success.dc_idx, - proto = ?success.proto_tag, - mode = "middle_proxy", - "Routing via Middle-End" - ); + if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { + warn!(user = %user, error = %e, "User limit exceeded"); + return Err(e); + } - let (conn_id, mut me_rx) = me_pool.registry().register().await; - - let our_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse().unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); - - stats.increment_user_connects(&user); - stats.increment_user_curr_connects(&user); - - let proto_flags = proto_flags_for_tag(success.proto_tag); - debug!(user = %user, conn_id, proto_flags = format_args!("0x{:08x}", proto_flags), "ME relay started"); - - // We need to handle framing here. - // Client sends: [Len:4][Payload...] (Intermediate/Secure) - // We must strip Len and send Payload to ME. - // ME sends: [Payload...] - // We must add [Len:4] and send to Client. - - // For Secure mode, Len has padding bit (MSB). - let is_secure = success.proto_tag == crate::protocol::constants::ProtoTag::Secure; - - let mut client_closed = false; - let mut server_closed = false; - - // Split client_reader/writer to use in select! - // CryptoReader/Writer don't support splitting easily without Arc/Mutex or unsafe, - // but here we are in a loop. - // We can't easily split them because they wrap the underlying stream. - // However, we can use a loop with select! on read and rx. - - let mut len_buf = [0u8; 4]; - let mut reading_len = true; - let mut current_payload_len = 0; - let mut payload_buf = Vec::new(); - - let result: Result<()> = loop { - tokio::select! { - // C->S: Read length, then payload - res = async { - if reading_len { - client_reader.read_exact(&mut len_buf).await.map(|_| true) - } else { - // Read payload - // We need to read exactly current_payload_len - if payload_buf.len() < current_payload_len { - let needed = current_payload_len - payload_buf.len(); - let mut chunk = vec![0u8; needed]; - let n = client_reader.read(&mut chunk).await?; - if n == 0 { return Ok(false); } // EOF - payload_buf.extend_from_slice(&chunk[..n]); - Ok(true) - } else { - Ok(true) // Should not happen - } - } - }, if !client_closed => { - match res { - Ok(true) => { - if reading_len { - // Got length - let raw_len = u32::from_le_bytes(len_buf); - // In secure mode, MSB is padding flag. In intermediate, it's just len. - // But wait, standard intermediate doesn't use MSB for padding. - // Secure mode DOES. - // Let's trust the protocol tag. - let len = if is_secure { - raw_len & 0x7FFFFFFF - } else { - raw_len - }; - - current_payload_len = len as usize; - // Sanity check - if current_payload_len > 16 * 1024 * 1024 { - debug!(conn_id, len=current_payload_len, "Client sent huge frame"); - break Err(ProxyError::Proxy("Frame too large".into())); - } - payload_buf.clear(); - payload_buf.reserve(current_payload_len); - reading_len = false; - } else { - // Got some payload data - if payload_buf.len() == current_payload_len { - // Full frame received - trace!(conn_id, bytes = current_payload_len, "C->ME (Frame complete)"); - stats.add_user_octets_from(&user, current_payload_len as u64); - - // Send to ME - // Note: In secure mode, we send the PADDING bytes too? - // Erlang mtp_intermediate: strips 4 bytes len. - // Erlang mtp_secure: strips 4 bytes len. - // The payload includes the padding if it was added? - // Actually, secure layer (mtp_secure.erl) handles padding removal? - // No, mtp_secure just sets padding=>true for intermediate codec. - // The intermediate codec (mtp_intermediate.erl) just extracts the packet. - // The packet passed to RPC is the payload. - // If secure mode adds random padding at the end, it is part of the payload - // that ME receives? - // Let's look at C code. - // ext-server.c: reads packet_len. - // if (packet_len & 0x80000000) -> has padding. - // It reads the full packet. - // Then it passes it to forward_tcp_query. - // So YES, we send the full payload including padding to ME. - - if let Err(e) = me_pool.send_proxy_req( - conn_id, peer, our_addr, &payload_buf, proto_flags - ).await { - break Err(e); - } - - // Reset for next frame - reading_len = true; - } - } - } - Ok(false) => { - // EOF - debug!(conn_id, "Client EOF"); - client_closed = true; - let _ = me_pool.send_close(conn_id).await; - if server_closed { break Ok(()); } - } - Err(e) => { - debug!(conn_id, error = %e, "Client read error"); - break Err(ProxyError::Io(e)); - } - } - } - - // S->C: ME sends data, we wrap and send to client - me_msg = me_rx.recv(), if !server_closed => { - match me_msg { - Some(MeResponse::Data(data)) => { - trace!(conn_id, bytes = data.len(), "ME->C"); - stats.add_user_octets_to(&user, data.len() as u64); - - // Wrap in intermediate frame - let len = data.len() as u32; - // For secure mode, we might need to add padding? - // C code: forward_mtproto_packet -> just sends data. - // But wait, C code adds framing in net-tcp-rpc-ext-server.c? - // No, forward_tcp_query sends RPC_PROXY_REQ. - // ME sends RPC_PROXY_ANS. - // The data in ANS is the MTProto packet. - // We need to send it to client. - // If client is Intermediate/Secure, we MUST add the 4-byte length prefix. - // Secure mode: usually we don't ADD padding on response, we just send valid packets. - // But we MUST send the length. - - if let Err(e) = client_writer.write_all(&len.to_le_bytes()).await { - break Err(ProxyError::Io(e)); - } - if let Err(e) = client_writer.write_all(&data).await { - break Err(ProxyError::Io(e)); - } - if let Err(e) = client_writer.flush().await { - break Err(ProxyError::Io(e)); - } - } - Some(MeResponse::Ack(_)) => { - trace!(conn_id, "ME ACK"); - } - Some(MeResponse::Close) => { - debug!(conn_id, "ME sent CLOSE"); - server_closed = true; - if client_closed { break Ok(()); } - // We should probably close client connection too - break Ok(()); - } - None => { - debug!(conn_id, "ME channel closed"); - server_closed = true; - if client_closed { break Ok(()); } - break Err(ProxyError::Proxy("ME connection lost".into())); - } - } - } + // Decide: middle proxy or direct + if config.general.use_middle_proxy { + if let Some(ref pool) = me_pool { + return handle_via_middle_proxy( + client_reader, + client_writer, + success, + pool.clone(), + stats, + config, + buffer_pool, + local_addr, + ) + .await; } - }; + warn!("use_middle_proxy=true but MePool not initialized, falling back to direct"); + } - // Cleanup - debug!(user = %user, conn_id, "ME relay cleanup"); - me_pool.registry().unregister(conn_id).await; - stats.decrement_user_curr_connects(&user); - result + // Direct mode (original behavior) + handle_via_direct( + client_reader, + client_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + ) + .await } - fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { - if let Some(expiration) = config.access.user_expirations.get(user) { - if chrono::Utc::now() > *expiration { - return Err(ProxyError::UserExpired { user: user.to_string() }); - } + if let Some(expiration) = config.access.user_expirations.get(user) { + if chrono::Utc::now() > *expiration { + return Err(ProxyError::UserExpired { + user: user.to_string(), + }); } - - if let Some(limit) = config.access.user_max_tcp_conns.get(user) { - if stats.get_user_curr_connects(user) >= *limit as u64 { - return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); - } - } - - if let Some(quota) = config.access.user_data_quota.get(user) { - if stats.get_user_total_octets(user) >= *quota { - return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); - } - } - - Ok(()) } - - /// Resolve DC index to target address (used only in direct mode) - fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { - let datacenters = if config.general.prefer_ipv6 { - &*TG_DATACENTERS_V6 - } else { - &*TG_DATACENTERS_V4 - }; - - let num_dcs = datacenters.len(); - - let dc_key = dc_idx.to_string(); - if let Some(addr_str) = config.dc_overrides.get(&dc_key) { - match addr_str.parse::() { - Ok(addr) => { - debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config"); - return Ok(addr); - } - Err(_) => { - warn!(dc_idx = dc_idx, addr_str = %addr_str, - "Invalid DC override address in config, ignoring"); - } - } - } - - let abs_dc = dc_idx.unsigned_abs() as usize; - if abs_dc >= 1 && abs_dc <= num_dcs { - return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT)); - } - - let default_dc = config.default_dc.unwrap_or(2) as usize; - let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { - default_dc - 1 - } else { - 1 - }; - - info!( - original_dc = dc_idx, - fallback_dc = (fallback_idx + 1) as u16, - fallback_addr = %datacenters[fallback_idx], - "Special DC ---> default_cluster" - ); - - Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT)) - } - - /// Perform obfuscated handshake with Telegram DC (direct mode only) - async fn do_tg_handshake_static( - mut stream: TcpStream, - success: &HandshakeSuccess, - config: &ProxyConfig, - rng: &SecureRandom, - ) -> Result<(CryptoReader, CryptoWriter)> { - let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( - success.proto_tag, - success.dc_idx, - &success.dec_key, - success.dec_iv, - rng, - config.general.fast_mode, - ); - let (encrypted_nonce, mut tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce); - - debug!( - peer = %success.peer, - nonce_head = %hex::encode(&nonce[..16]), - "Sending nonce to Telegram" - ); - - stream.write_all(&encrypted_nonce).await?; - stream.flush().await?; - - let (read_half, write_half) = stream.into_split(); - - Ok(( - CryptoReader::new(read_half, tg_decryptor), - CryptoWriter::new(write_half, tg_encryptor), - )) + if let Some(limit) = config.access.user_max_tcp_conns.get(user) { + if stats.get_user_curr_connects(user) >= *limit as u64 { + return Err(ProxyError::ConnectionLimitExceeded { + user: user.to_string(), + }); + } } + + if let Some(quota) = config.access.user_data_quota.get(user) { + if stats.get_user_total_octets(user) >= *quota { + return Err(ProxyError::DataQuotaExceeded { + user: user.to_string(), + }); + } + } + + Ok(()) } - \ No newline at end of file +} diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs new file mode 100644 index 0000000..46c004c --- /dev/null +++ b/src/proxy/direct_relay.rs @@ -0,0 +1,163 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::net::TcpStream; +use tracing::{debug, info, warn}; + +use crate::config::ProxyConfig; +use crate::crypto::SecureRandom; +use crate::error::Result; +use crate::protocol::constants::*; +use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; +use crate::proxy::relay::relay_bidirectional; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::UpstreamManager; + +pub(crate) async fn handle_via_direct( + client_reader: CryptoReader, + client_writer: CryptoWriter, + success: HandshakeSuccess, + upstream_manager: Arc, + stats: Arc, + config: Arc, + buffer_pool: Arc, + rng: Arc, +) -> Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let user = &success.user; + let dc_addr = get_dc_addr_static(success.dc_idx, &config)?; + + info!( + user = %user, + peer = %success.peer, + dc = success.dc_idx, + dc_addr = %dc_addr, + proto = ?success.proto_tag, + mode = "direct", + "Connecting to Telegram DC" + ); + + let tg_stream = upstream_manager + .connect(dc_addr, Some(success.dc_idx)) + .await?; + + debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake"); + + let (tg_reader, tg_writer) = + do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?; + + debug!(peer = %success.peer, "TG handshake complete, starting relay"); + + stats.increment_user_connects(user); + stats.increment_user_curr_connects(user); + + let relay_result = relay_bidirectional( + client_reader, + client_writer, + tg_reader, + tg_writer, + user, + Arc::clone(&stats), + buffer_pool, + ) + .await; + + stats.decrement_user_curr_connects(user); + + match &relay_result { + Ok(()) => debug!(user = %user, "Direct relay completed"), + Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"), + } + + relay_result +} + +fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { + let datacenters = if config.general.prefer_ipv6 { + &*TG_DATACENTERS_V6 + } else { + &*TG_DATACENTERS_V4 + }; + + let num_dcs = datacenters.len(); + + let dc_key = dc_idx.to_string(); + if let Some(addr_str) = config.dc_overrides.get(&dc_key) { + match addr_str.parse::() { + Ok(addr) => { + debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config"); + return Ok(addr); + } + Err(_) => { + warn!(dc_idx = dc_idx, addr_str = %addr_str, + "Invalid DC override address in config, ignoring"); + } + } + } + + let abs_dc = dc_idx.unsigned_abs() as usize; + if abs_dc >= 1 && abs_dc <= num_dcs { + return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT)); + } + + let default_dc = config.default_dc.unwrap_or(2) as usize; + let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { + default_dc - 1 + } else { + 1 + }; + + info!( + original_dc = dc_idx, + fallback_dc = (fallback_idx + 1) as u16, + fallback_addr = %datacenters[fallback_idx], + "Special DC ---> default_cluster" + ); + + Ok(SocketAddr::new( + datacenters[fallback_idx], + TG_DATACENTER_PORT, + )) +} + +async fn do_tg_handshake_static( + mut stream: TcpStream, + success: &HandshakeSuccess, + config: &ProxyConfig, + rng: &SecureRandom, +) -> Result<( + CryptoReader, + CryptoWriter, +)> { + let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( + success.proto_tag, + success.dc_idx, + &success.dec_key, + success.dec_iv, + rng, + config.general.fast_mode, + ); + + let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce); + + debug!( + peer = %success.peer, + nonce_head = %hex::encode(&nonce[..16]), + "Sending nonce to Telegram" + ); + + stream.write_all(&encrypted_nonce).await?; + stream.flush().await?; + + let (read_half, write_half) = stream.into_split(); + + Ok(( + CryptoReader::new(read_half, tg_decryptor), + CryptoWriter::new(write_half, tg_encryptor), + )) +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs new file mode 100644 index 0000000..fe15d32 --- /dev/null +++ b/src/proxy/middle_relay.rs @@ -0,0 +1,247 @@ +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tracing::{debug, info, trace}; + +use crate::config::ProxyConfig; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; +use crate::proxy::handshake::HandshakeSuccess; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; + +pub(crate) async fn handle_via_middle_proxy( + mut crypto_reader: CryptoReader, + mut crypto_writer: CryptoWriter, + success: HandshakeSuccess, + me_pool: Arc, + stats: Arc, + _config: Arc, + _buffer_pool: Arc, + local_addr: SocketAddr, +) -> Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + let user = success.user.clone(); + let peer = success.peer; + let proto_tag = success.proto_tag; + + info!( + user = %user, + peer = %peer, + dc = success.dc_idx, + proto = ?proto_tag, + mode = "middle_proxy", + "Routing via Middle-End" + ); + + let (conn_id, mut me_rx) = me_pool.registry().register().await; + + stats.increment_user_connects(&user); + stats.increment_user_curr_connects(&user); + + let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag()); + debug!( + user = %user, + conn_id, + proto_flags = format_args!("0x{:08x}", proto_flags), + "ME relay started" + ); + + let translated_local_addr = me_pool.translate_our_addr(local_addr); + + let result: Result<()> = loop { + tokio::select! { + client_frame = read_client_payload(&mut crypto_reader, proto_tag) => { + match client_frame { + Ok(Some(payload)) => { + trace!(conn_id, bytes = payload.len(), "C->ME frame"); + stats.add_user_octets_from(&user, payload.len() as u64); + me_pool.send_proxy_req(conn_id, peer, translated_local_addr, &payload, proto_flags).await?; + } + Ok(None) => { + debug!(conn_id, "Client EOF"); + let _ = me_pool.send_close(conn_id).await; + break Ok(()); + } + Err(e) => break Err(e), + } + } + me_msg = me_rx.recv() => { + match me_msg { + Some(MeResponse::Data { flags, data }) => { + trace!(conn_id, bytes = data.len(), flags, "ME->C data"); + stats.add_user_octets_to(&user, data.len() as u64); + write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?; + } + Some(MeResponse::Ack(confirm)) => { + trace!(conn_id, confirm, "ME->C quickack"); + write_client_ack(&mut crypto_writer, proto_tag, confirm).await?; + } + Some(MeResponse::Close) => { + debug!(conn_id, "ME sent close"); + break Ok(()); + } + None => { + debug!(conn_id, "ME channel closed"); + break Err(ProxyError::Proxy("ME connection lost".into())); + } + } + } + } + }; + + debug!(user = %user, conn_id, "ME relay cleanup"); + me_pool.registry().unregister(conn_id).await; + stats.decrement_user_curr_connects(&user); + result +} + +async fn read_client_payload( + client_reader: &mut CryptoReader, + proto_tag: ProtoTag, +) -> Result>> +where + R: AsyncRead + Unpin + Send + 'static, +{ + let len = match proto_tag { + ProtoTag::Abridged => { + let mut first = [0u8; 1]; + match client_reader.read_exact(&mut first).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + + let len_words = if (first[0] & 0x7f) == 0x7f { + let mut ext = [0u8; 3]; + client_reader + .read_exact(&mut ext) + .await + .map_err(ProxyError::Io)?; + u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize + } else { + (first[0] & 0x7f) as usize + }; + + len_words + .checked_mul(4) + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))? + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len_buf = [0u8; 4]; + match client_reader.read_exact(&mut len_buf).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize + } + }; + + if len > 16 * 1024 * 1024 { + return Err(ProxyError::Proxy(format!("Frame too large: {len}"))); + } + + let mut payload = vec![0u8; len]; + client_reader + .read_exact(&mut payload) + .await + .map_err(ProxyError::Io)?; + Ok(Some(payload)) +} + +async fn write_client_payload( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + flags: u32, + data: &[u8], +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let quickack = (flags & RPC_FLAG_QUICKACK) != 0; + + match proto_tag { + ProtoTag::Abridged => { + if data.len() % 4 != 0 { + return Err(ProxyError::Proxy(format!( + "Abridged payload must be 4-byte aligned, got {}", + data.len() + ))); + } + + let len_words = data.len() / 4; + if len_words < 0x7f { + let mut first = len_words as u8; + if quickack { + first |= 0x80; + } + client_writer + .write_all(&[first]) + .await + .map_err(ProxyError::Io)?; + } else if len_words < (1 << 24) { + let mut first = 0x7fu8; + if quickack { + first |= 0x80; + } + let lw = (len_words as u32).to_le_bytes(); + client_writer + .write_all(&[first, lw[0], lw[1], lw[2]]) + .await + .map_err(ProxyError::Io)?; + } else { + return Err(ProxyError::Proxy(format!( + "Abridged frame too large: {}", + data.len() + ))); + } + + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; + } + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len = data.len() as u32; + if quickack { + len |= 0x8000_0000; + } + client_writer + .write_all(&len.to_le_bytes()) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; + } + } + + client_writer.flush().await.map_err(ProxyError::Io) +} + +async fn write_client_ack( + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + confirm: u32, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + let bytes = if proto_tag == ProtoTag::Abridged { + confirm.to_be_bytes() + } else { + confirm.to_le_bytes() + }; + client_writer + .write_all(&bytes) + .await + .map_err(ProxyError::Io)?; + client_writer.flush().await.map_err(ProxyError::Io) +} diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs new file mode 100644 index 0000000..4eaaa4c --- /dev/null +++ b/src/transport/middle_proxy/codec.rs @@ -0,0 +1,178 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use crate::crypto::{AesCbc, crc32}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { + let total_len = (4 + 4 + payload.len() + 4) as u32; + let mut frame = Vec::with_capacity(total_len as usize); + frame.extend_from_slice(&total_len.to_le_bytes()); + frame.extend_from_slice(&seq_no.to_le_bytes()); + frame.extend_from_slice(payload); + let c = crc32(&frame); + frame.extend_from_slice(&c.to_le_bytes()); + frame +} + +pub(crate) async fn read_rpc_frame_plaintext( + rd: &mut (impl AsyncReadExt + Unpin), +) -> Result<(i32, Vec)> { + let mut len_buf = [0u8; 4]; + rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?; + let total_len = u32::from_le_bytes(len_buf) as usize; + + if !(12..=(1 << 24)).contains(&total_len) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad RPC frame length: {total_len}" + ))); + } + + let mut rest = vec![0u8; total_len - 4]; + rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?; + + let mut full = Vec::with_capacity(total_len); + full.extend_from_slice(&len_buf); + full.extend_from_slice(&rest); + + let crc_offset = total_len - 4; + let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); + let actual_crc = crc32(&full[..crc_offset]); + if expected_crc != actual_crc { + return Err(ProxyError::InvalidHandshake(format!( + "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" + ))); + } + + let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap()); + let payload = full[8..crc_offset].to_vec(); + Ok((seq_no, payload)) +} + +pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] { + let mut p = [0u8; 32]; + p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes()); + p[4..8].copy_from_slice(&key_selector.to_le_bytes()); + p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes()); + p[12..16].copy_from_slice(&crypto_ts.to_le_bytes()); + p[16..32].copy_from_slice(nonce); + p +} + +pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> { + if d.len() < 32 { + return Err(ProxyError::InvalidHandshake(format!( + "Nonce payload too short: {} bytes", + d.len() + ))); + } + + let t = u32::from_le_bytes(d[0..4].try_into().unwrap()); + if t != RPC_NONCE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}" + ))); + } + + let schema = u32::from_le_bytes(d[8..12].try_into().unwrap()); + let ts = u32::from_le_bytes(d[12..16].try_into().unwrap()); + let mut nonce = [0u8; 16]; + nonce.copy_from_slice(&d[16..32]); + Ok((schema, ts, nonce)) +} + +pub(crate) fn build_handshake_payload( + our_ip: [u8; 4], + our_port: u16, + peer_ip: [u8; 4], + peer_port: u16, +) -> [u8; 32] { + let mut p = [0u8; 32]; + p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); + + // Keep C memory layout compatibility for PID IPv4 bytes. + p[8..12].copy_from_slice(&our_ip); + p[12..14].copy_from_slice(&our_port.to_le_bytes()); + let pid = (std::process::id() & 0xffff) as u16; + p[14..16].copy_from_slice(&pid.to_le_bytes()); + let utime = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + p[16..20].copy_from_slice(&utime.to_le_bytes()); + + p[20..24].copy_from_slice(&peer_ip); + p[24..26].copy_from_slice(&peer_port.to_le_bytes()); + p +} + +pub(crate) fn cbc_encrypt_padded( + key: &[u8; 32], + iv: &[u8; 16], + plaintext: &[u8], +) -> Result<(Vec, [u8; 16])> { + let pad = (16 - (plaintext.len() % 16)) % 16; + let mut buf = plaintext.to_vec(); + let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; + for i in 0..pad { + buf.push(pad_pattern[i % 4]); + } + + let cipher = AesCbc::new(*key, *iv); + cipher + .encrypt_in_place(&mut buf) + .map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?; + + let mut new_iv = [0u8; 16]; + if buf.len() >= 16 { + new_iv.copy_from_slice(&buf[buf.len() - 16..]); + } + Ok((buf, new_iv)) +} + +pub(crate) fn cbc_decrypt_inplace( + key: &[u8; 32], + iv: &[u8; 16], + data: &mut [u8], +) -> Result<[u8; 16]> { + let mut new_iv = [0u8; 16]; + if data.len() >= 16 { + new_iv.copy_from_slice(&data[data.len() - 16..]); + } + + AesCbc::new(*key, *iv) + .decrypt_in_place(data) + .map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?; + Ok(new_iv) +} + +pub(crate) struct RpcWriter { + pub(crate) writer: tokio::io::WriteHalf, + pub(crate) key: [u8; 32], + pub(crate) iv: [u8; 16], + pub(crate) seq_no: i32, +} + +impl RpcWriter { + pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { + let frame = build_rpc_frame(self.seq_no, payload); + self.seq_no += 1; + + let pad = (16 - (frame.len() % 16)) % 16; + let mut buf = frame; + let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00]; + for i in 0..pad { + buf.push(pad_pattern[i % 4]); + } + + let cipher = AesCbc::new(self.key, self.iv); + cipher + .encrypt_in_place(&mut buf) + .map_err(|e| ProxyError::Crypto(format!("{e}")))?; + + if buf.len() >= 16 { + self.iv.copy_from_slice(&buf[buf.len() - 16..]); + } + self.writer.write_all(&buf).await.map_err(ProxyError::Io) + } +} diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs new file mode 100644 index 0000000..d720c86 --- /dev/null +++ b/src/transport/middle_proxy/health.rs @@ -0,0 +1,38 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use tracing::{debug, info, warn}; + +use crate::crypto::SecureRandom; +use crate::protocol::constants::TG_MIDDLE_PROXIES_FLAT_V4; + +use super::MePool; + +pub async fn me_health_monitor(pool: Arc, rng: Arc, min_connections: usize) { + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + let current = pool.connection_count(); + if current < min_connections { + warn!( + current, + min = min_connections, + "ME pool below minimum, reconnecting..." + ); + let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone(); + for &(ip, port) in addrs.iter() { + let needed = min_connections.saturating_sub(pool.connection_count()); + if needed == 0 { + break; + } + for _ in 0..needed { + let addr = SocketAddr::new(ip, port); + match pool.connect_one(addr, &rng).await { + Ok(()) => info!(%addr, "ME reconnected"), + Err(e) => debug!(%addr, error = %e, "ME reconnect failed"), + } + } + } + } + } +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs new file mode 100644 index 0000000..8aad640 --- /dev/null +++ b/src/transport/middle_proxy/mod.rs @@ -0,0 +1,24 @@ +//! Middle Proxy RPC transport. + +mod codec; +mod health; +mod pool; +mod reader; +mod registry; +mod secret; +mod wire; + +use bytes::Bytes; + +pub use health::me_health_monitor; +pub use pool::MePool; +pub use registry::ConnRegistry; +pub use secret::fetch_proxy_secret; +pub use wire::proto_flags_for_tag; + +#[derive(Debug)] +pub enum MeResponse { + Data { flags: u32, data: Bytes }, + Ack(u32), + Close, +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs new file mode 100644 index 0000000..f256aab --- /dev/null +++ b/src/transport/middle_proxy/pool.rs @@ -0,0 +1,431 @@ +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +use bytes::BytesMut; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::{Instant, timeout}; +use tracing::{debug, info, warn}; + +use crate::crypto::{SecureRandom, derive_middleproxy_keys}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +use super::ConnRegistry; +use super::codec::{ + RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, + cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, +}; +use super::reader::reader_loop; +use super::wire::{IpMaterial, build_proxy_req_payload, extract_ip_material}; + +pub struct MePool { + registry: Arc, + writers: Arc>>>>, + rr: AtomicU64, + proxy_tag: Option>, + proxy_secret: Vec, + nat_ip: Option, + pool_size: usize, +} + +impl MePool { + pub fn new( + proxy_tag: Option>, + proxy_secret: Vec, + nat_ip: Option, + ) -> Arc { + Arc::new(Self { + registry: Arc::new(ConnRegistry::new()), + writers: Arc::new(RwLock::new(Vec::new())), + rr: AtomicU64::new(0), + proxy_tag, + proxy_secret, + nat_ip, + pool_size: 2, + }) + } + + pub fn has_proxy_tag(&self) -> bool { + self.proxy_tag.is_some() + } + + pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr { + let ip = self.translate_ip_for_nat(addr.ip()); + SocketAddr::new(ip, addr.port()) + } + + pub fn registry(&self) -> &Arc { + &self.registry + } + + fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { + let Some(nat_ip) = self.nat_ip else { + return ip; + }; + + match (ip, nat_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) + if src.is_private() || src.is_loopback() || src.is_unspecified() => + { + IpAddr::V4(dst) + } + (IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => { + IpAddr::V6(dst) + } + (orig, _) => orig, + } + } + + fn writers_arc(&self) -> Arc>>>> { + self.writers.clone() + } + + fn key_selector(&self) -> u32 { + if self.proxy_secret.len() >= 4 { + u32::from_le_bytes([ + self.proxy_secret[0], + self.proxy_secret[1], + self.proxy_secret[2], + self.proxy_secret[3], + ]) + } else { + 0 + } + } + + pub async fn init(self: &Arc, pool_size: usize, rng: &SecureRandom) -> Result<()> { + let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4; + let ks = self.key_selector(); + info!( + me_servers = addrs.len(), + pool_size, + key_selector = format_args!("0x{ks:08x}"), + secret_len = self.proxy_secret.len(), + "Initializing ME pool" + ); + + for &(ip, port) in addrs.iter() { + for i in 0..pool_size { + let addr = SocketAddr::new(ip, port); + match self.connect_one(addr, rng).await { + Ok(()) => info!(%addr, idx = i, "ME connected"), + Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"), + } + } + if self.writers.read().await.len() >= pool_size { + break; + } + } + + if self.writers.read().await.is_empty() { + return Err(ProxyError::Proxy("No ME connections".into())); + } + Ok(()) + } + + pub(crate) async fn connect_one( + self: &Arc, + addr: SocketAddr, + rng: &SecureRandom, + ) -> Result<()> { + let secret = &self.proxy_secret; + if secret.len() < 32 { + return Err(ProxyError::Proxy( + "proxy-secret too short for ME auth".into(), + )); + } + + let stream = timeout( + Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), + TcpStream::connect(addr), + ) + .await + .map_err(|_| ProxyError::ConnectionTimeout { + addr: addr.to_string(), + })? + .map_err(ProxyError::Io)?; + stream.set_nodelay(true).ok(); + + let local_addr = stream.local_addr().map_err(ProxyError::Io)?; + let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; + let local_addr_nat = self.translate_our_addr(local_addr); + let peer_addr_nat = + SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let (mut rd, mut wr) = tokio::io::split(stream); + + let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); + let crypto_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + + let ks = self.key_selector(); + let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); + let nonce_frame = build_rpc_frame(-2, &nonce_payload); + wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let (srv_seq, srv_nonce_payload) = timeout( + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), + read_rpc_frame_plaintext(&mut rd), + ) + .await + .map_err(|_| ProxyError::TgHandshakeTimeout)??; + + if srv_seq != -2 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected seq=-2, got {srv_seq}" + ))); + } + + let (schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; + if schema != RPC_CRYPTO_AES_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Unsupported crypto schema: 0x{schema:x}" + ))); + } + + let skew = crypto_ts.abs_diff(srv_ts); + if skew > 30 { + return Err(ProxyError::InvalidHandshake(format!( + "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" + ))); + } + + let ts_bytes = crypto_ts.to_le_bytes(); + let server_port_bytes = peer_addr_nat.port().to_le_bytes(); + let client_port_bytes = local_addr_nat.port().to_le_bytes(); + + let server_ip = extract_ip_material(peer_addr_nat); + let client_ip = extract_ip_material(local_addr_nat); + + let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = + match (server_ip, client_ip) { + (IpMaterial::V4(srv), IpMaterial::V4(clt)) => { + (Some(srv), Some(clt), None, None, clt, srv) + } + (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { + let zero = [0u8; 4]; + (None, None, Some(clt), Some(srv), zero, zero) + } + _ => { + return Err(ProxyError::InvalidHandshake( + "mixed IPv4/IPv6 endpoints are not supported for ME key derivation" + .to_string(), + )); + } + }; + + let (wk, wi) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let (rk, ri) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let hs_payload = + build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); + let hs_frame = build_rpc_frame(-1, &hs_payload); + let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); + let mut enc_buf = BytesMut::with_capacity(256); + let mut dec_buf = BytesMut::with_capacity(256); + let mut read_iv = ri; + let mut handshake_ok = false; + + while Instant::now() < deadline && !handshake_ok { + let remaining = deadline - Instant::now(); + let mut tmp = [0u8; 256]; + let n = match timeout(remaining, rd.read(&mut tmp)).await { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ME closed during handshake", + ))); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => return Err(ProxyError::TgHandshakeTimeout), + }; + + enc_buf.extend_from_slice(&tmp[..n]); + + let blocks = enc_buf.len() / 16 * 16; + if blocks > 0 { + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&enc_buf[..blocks]); + read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; + dec_buf.extend_from_slice(&chunk); + let _ = enc_buf.split_to(blocks); + } + + while dec_buf.len() >= 4 { + let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; + + if fl == 4 { + let _ = dec_buf.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad HS response frame len: {fl}" + ))); + } + if dec_buf.len() < fl { + break; + } + + let frame = dec_buf.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + let ac = crate::crypto::crc32(&frame[..pe]); + if ec != ac { + return Err(ProxyError::InvalidHandshake(format!( + "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" + ))); + } + + let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + if hs_type == RPC_HANDSHAKE_ERROR_U32 { + let err_code = if frame.len() >= 16 { + i32::from_le_bytes(frame[12..16].try_into().unwrap()) + } else { + -1 + }; + return Err(ProxyError::InvalidHandshake(format!( + "ME rejected handshake (error={err_code})" + ))); + } + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + + handshake_ok = true; + break; + } + } + + if !handshake_ok { + return Err(ProxyError::TgHandshakeTimeout); + } + + info!(%addr, "RPC handshake OK"); + + let rpc_w = Arc::new(Mutex::new(RpcWriter { + writer: wr, + key: wk, + iv: write_iv, + seq_no: 0, + })); + self.writers.write().await.push(rpc_w.clone()); + + let reg = self.registry.clone(); + let w_pong = rpc_w.clone(); + let w_pool = self.writers_arc(); + tokio::spawn(async move { + if let Err(e) = + reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await + { + warn!(error = %e, "ME reader ended"); + } + let mut ws = w_pool.write().await; + ws.retain(|w| !Arc::ptr_eq(w, &w_pong)); + info!(remaining = ws.len(), "Dead ME writer removed from pool"); + }); + + Ok(()) + } + + pub async fn send_proxy_req( + &self, + conn_id: u64, + client_addr: SocketAddr, + our_addr: SocketAddr, + data: &[u8], + proto_flags: u32, + ) -> Result<()> { + let payload = build_proxy_req_payload( + conn_id, + client_addr, + our_addr, + data, + self.proxy_tag.as_deref(), + proto_flags, + ); + + loop { + let ws = self.writers.read().await; + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + + let idx = self.rr.fetch_add(1, Ordering::Relaxed) as usize % ws.len(); + let w = ws[idx].clone(); + drop(ws); + + match w.lock().await.send(&payload).await { + Ok(()) => return Ok(()), + Err(e) => { + warn!(error = %e, "ME write failed, removing dead conn"); + let mut ws = self.writers.write().await; + ws.retain(|o| !Arc::ptr_eq(o, &w)); + if ws.is_empty() { + return Err(ProxyError::Proxy("All ME connections dead".into())); + } + } + } + } + } + + pub async fn send_close(&self, conn_id: u64) -> Result<()> { + let ws = self.writers.read().await; + if !ws.is_empty() { + let w = ws[0].clone(); + drop(ws); + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); + p.extend_from_slice(&conn_id.to_le_bytes()); + if let Err(e) = w.lock().await.send(&p).await { + debug!(error = %e, "ME close write failed"); + let mut ws = self.writers.write().await; + ws.retain(|o| !Arc::ptr_eq(o, &w)); + } + } + + self.registry.unregister(conn_id).await; + Ok(()) + } + + pub fn connection_count(&self) -> usize { + self.writers.try_read().map(|w| w.len()).unwrap_or(0) + } +} diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs new file mode 100644 index 0000000..83df742 --- /dev/null +++ b/src/transport/middle_proxy/reader.rs @@ -0,0 +1,141 @@ +use std::sync::Arc; + +use bytes::{Bytes, BytesMut}; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tracing::{debug, trace, warn}; + +use crate::crypto::{AesCbc, crc32}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::*; + +use super::codec::RpcWriter; +use super::{ConnRegistry, MeResponse}; + +pub(crate) async fn reader_loop( + mut rd: tokio::io::ReadHalf, + dk: [u8; 32], + mut div: [u8; 16], + reg: Arc, + enc_leftover: BytesMut, + mut dec: BytesMut, + writer: Arc>, +) -> Result<()> { + let mut raw = enc_leftover; + + loop { + let mut tmp = [0u8; 16_384]; + let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?; + if n == 0 { + return Ok(()); + } + raw.extend_from_slice(&tmp[..n]); + + let blocks = raw.len() / 16 * 16; + if blocks > 0 { + let mut new_iv = [0u8; 16]; + new_iv.copy_from_slice(&raw[blocks - 16..blocks]); + + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&raw[..blocks]); + AesCbc::new(dk, div) + .decrypt_in_place(&mut chunk) + .map_err(|e| ProxyError::Crypto(format!("{e}")))?; + div = new_iv; + dec.extend_from_slice(&chunk); + let _ = raw.split_to(blocks); + } + + while dec.len() >= 12 { + let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize; + if fl == 4 { + let _ = dec.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + warn!(frame_len = fl, "Invalid RPC frame len"); + dec.clear(); + break; + } + if dec.len() < fl { + break; + } + + let frame = dec.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + if crc32(&frame[..pe]) != ec { + warn!("CRC mismatch in data frame"); + continue; + } + + let payload = &frame[8..pe]; + if payload.len() < 4 { + continue; + } + + let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap()); + let body = &payload[4..]; + + if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 { + let flags = u32::from_le_bytes(body[0..4].try_into().unwrap()); + let cid = u64::from_le_bytes(body[4..12].try_into().unwrap()); + let data = Bytes::copy_from_slice(&body[12..]); + trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); + + let routed = reg.route(cid, MeResponse::Data { flags, data }).await; + if !routed { + reg.unregister(cid).await; + send_close_conn(&writer, cid).await; + } + } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap()); + trace!(cid, cfm, "RPC_SIMPLE_ACK"); + + let routed = reg.route(cid, MeResponse::Ack(cfm)).await; + if !routed { + reg.unregister(cid).await; + send_close_conn(&writer, cid).await; + } + } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + debug!(cid, "RPC_CLOSE_EXT from ME"); + reg.route(cid, MeResponse::Close).await; + reg.unregister(cid).await; + } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { + let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); + debug!(cid, "RPC_CLOSE_CONN from ME"); + reg.route(cid, MeResponse::Close).await; + reg.unregister(cid).await; + } else if pt == RPC_PING_U32 && body.len() >= 8 { + let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); + trace!(ping_id, "RPC_PING -> RPC_PONG"); + let mut pong = Vec::with_capacity(12); + pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); + pong.extend_from_slice(&ping_id.to_le_bytes()); + if let Err(e) = writer.lock().await.send(&pong).await { + warn!(error = %e, "PONG send failed"); + break; + } + } else { + debug!( + rpc_type = format_args!("0x{pt:08x}"), + len = body.len(), + "Unknown RPC" + ); + } + } + } +} + +async fn send_close_conn(writer: &Arc>, conn_id: u64) { + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); + p.extend_from_slice(&conn_id.to_le_bytes()); + + if let Err(e) = writer.lock().await.send(&p).await { + debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN"); + } +} diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs new file mode 100644 index 0000000..40c2803 --- /dev/null +++ b/src/transport/middle_proxy/registry.rs @@ -0,0 +1,40 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; + +use tokio::sync::{RwLock, mpsc}; + +use super::MeResponse; + +pub struct ConnRegistry { + map: RwLock>>, + next_id: AtomicU64, +} + +impl ConnRegistry { + pub fn new() -> Self { + Self { + map: RwLock::new(HashMap::new()), + next_id: AtomicU64::new(1), + } + } + + pub async fn register(&self) -> (u64, mpsc::Receiver) { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + let (tx, rx) = mpsc::channel(256); + self.map.write().await.insert(id, tx); + (id, rx) + } + + pub async fn unregister(&self, id: u64) { + self.map.write().await.remove(&id); + } + + pub async fn route(&self, id: u64, resp: MeResponse) -> bool { + let m = self.map.read().await; + if let Some(tx) = m.get(&id) { + tx.send(resp).await.is_ok() + } else { + false + } + } +} diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs new file mode 100644 index 0000000..5b201fd --- /dev/null +++ b/src/transport/middle_proxy/secret.rs @@ -0,0 +1,76 @@ +use std::time::Duration; + +use tracing::{debug, info, warn}; + +use crate::error::{ProxyError, Result}; + +/// Fetch Telegram proxy-secret binary. +pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result> { + let cache = cache_path.unwrap_or("proxy-secret"); + + if let Ok(metadata) = tokio::fs::metadata(cache).await { + if let Ok(modified) = metadata.modified() { + let age = std::time::SystemTime::now() + .duration_since(modified) + .unwrap_or(Duration::from_secs(u64::MAX)); + if age < Duration::from_secs(86_400) { + if let Ok(data) = tokio::fs::read(cache).await { + if data.len() >= 32 { + info!( + path = cache, + len = data.len(), + age_hours = age.as_secs() / 3600, + "Loaded proxy-secret from cache" + ); + return Ok(data); + } + warn!( + path = cache, + len = data.len(), + "Cached proxy-secret too short" + ); + } + } + } + } + + info!("Downloading proxy-secret from core.telegram.org..."); + let data = download_proxy_secret().await?; + + if let Err(e) = tokio::fs::write(cache, &data).await { + warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); + } else { + debug!(path = cache, len = data.len(), "Cached proxy-secret"); + } + + Ok(data) +} + +async fn download_proxy_secret() -> Result> { + let resp = reqwest::get("https://core.telegram.org/getProxySecret") + .await + .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + + if !resp.status().is_success() { + return Err(ProxyError::Proxy(format!( + "proxy-secret download HTTP {}", + resp.status() + ))); + } + + let data = resp + .bytes() + .await + .map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))? + .to_vec(); + + if data.len() < 32 { + return Err(ProxyError::Proxy(format!( + "proxy-secret too short: {} bytes (need >= 32)", + data.len() + ))); + } + + info!(len = data.len(), "Downloaded proxy-secret OK"); + Ok(data) +} diff --git a/src/transport/middle_proxy/wire.rs b/src/transport/middle_proxy/wire.rs new file mode 100644 index 0000000..1ed9727 --- /dev/null +++ b/src/transport/middle_proxy/wire.rs @@ -0,0 +1,106 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use crate::protocol::constants::*; + +#[derive(Clone, Copy)] +pub(crate) enum IpMaterial { + V4([u8; 4]), + V6([u8; 16]), +} + +pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial { + match addr.ip() { + IpAddr::V4(v4) => IpMaterial::V4(v4.octets()), + IpAddr::V6(v6) => { + if let Some(v4) = v6.to_ipv4_mapped() { + IpMaterial::V4(v4.octets()) + } else { + IpMaterial::V6(v6.octets()) + } + } + } +} + +fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] { + let mut buf = [0u8; 16]; + + // Matches tl_store_long(0) + tl_store_int(-0x10000). + buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes()); + + // Matches tl_store_int(htonl(remote_ip_host_order)). + let host_order = u32::from_ne_bytes(ip.octets()); + let network_order = host_order.to_be(); + buf[12..16].copy_from_slice(&network_order.to_le_bytes()); + + buf +} + +fn append_mapped_addr_and_port(buf: &mut Vec, addr: SocketAddr) { + match addr.ip() { + IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)), + IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()), + } + buf.extend_from_slice(&(addr.port() as u32).to_le_bytes()); +} + +pub(crate) fn build_proxy_req_payload( + conn_id: u64, + client_addr: SocketAddr, + our_addr: SocketAddr, + data: &[u8], + proxy_tag: Option<&[u8]>, + proto_flags: u32, +) -> Vec { + let mut b = Vec::with_capacity(128 + data.len()); + + b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); + b.extend_from_slice(&proto_flags.to_le_bytes()); + b.extend_from_slice(&conn_id.to_le_bytes()); + + append_mapped_addr_and_port(&mut b, client_addr); + append_mapped_addr_and_port(&mut b, our_addr); + + if proto_flags & 12 != 0 { + let extra_start = b.len(); + b.extend_from_slice(&0u32.to_le_bytes()); + + if let Some(tag) = proxy_tag { + b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes()); + + if tag.len() < 254 { + b.push(tag.len() as u8); + b.extend_from_slice(tag); + let pad = (4 - ((1 + tag.len()) % 4)) % 4; + b.extend(std::iter::repeat_n(0u8, pad)); + } else { + b.push(0xfe); + let len_bytes = (tag.len() as u32).to_le_bytes(); + b.extend_from_slice(&len_bytes[..3]); + b.extend_from_slice(tag); + let pad = (4 - (tag.len() % 4)) % 4; + b.extend(std::iter::repeat_n(0u8, pad)); + } + } + + let extra_bytes = (b.len() - extra_start - 4) as u32; + b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes()); + } + + b.extend_from_slice(data); + b +} + +pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 { + use crate::protocol::constants::ProtoTag; + + let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2; + if has_proxy_tag { + flags |= RPC_FLAG_HAS_AD_TAG; + } + + match tag { + ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED, + ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE, + ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE, + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 10becae..51cffa4 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -10,5 +10,5 @@ pub use pool::ConnectionPool; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use socket::*; pub use socks::*; -pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; +pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager}; pub mod middle_proxy;