Middle Proxy is so real

Alexey
2026-02-14 01:36:14 +03:00
parent 9b850b0bfb
commit 70859aa5cf
14 changed files with 2028 additions and 785 deletions

View File

@@ -1,32 +1,55 @@
//! Configuration

use crate::error::{ProxyError, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::path::Path;

// ============= Helper Defaults =============

fn default_true() -> bool {
    true
}
fn default_port() -> u16 {
    443
}
fn default_tls_domain() -> String {
    "www.google.com".to_string()
}
fn default_mask_port() -> u16 {
    443
}
fn default_replay_check_len() -> usize {
    65536
}
fn default_replay_window_secs() -> u64 {
    1800
}
fn default_handshake_timeout() -> u64 {
    15
}
fn default_connect_timeout() -> u64 {
    10
}
fn default_keepalive() -> u64 {
    60
}
fn default_ack_timeout() -> u64 {
    300
}
fn default_listen_addr() -> String {
    "0.0.0.0".to_string()
}
fn default_fake_cert_len() -> usize {
    2048
}
fn default_weight() -> u16 {
    1
}

fn default_metrics_whitelist() -> Vec<IpAddr> {
    vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()]
}

// ============= Log Level =============

@@ -96,7 +119,11 @@ pub struct ProxyModes

impl Default for ProxyModes {
    fn default() -> Self {
        Self {
            classic: true,
            secure: true,
            tls: true,
        }
    }
}

@@ -104,13 +131,13 @@ impl Default for ProxyModes

pub struct GeneralConfig {
    #[serde(default)]
    pub modes: ProxyModes,

    #[serde(default)]
    pub prefer_ipv6: bool,

    #[serde(default = "default_true")]
    pub fast_mode: bool,

    #[serde(default)]
    pub use_middle_proxy: bool,

@@ -121,7 +148,12 @@ pub struct GeneralConfig

    /// Infrastructure secret from https://core.telegram.org/getProxySecret
    #[serde(default)]
    pub proxy_secret_path: Option<String>,

    /// Public IP override for middle-proxy NAT environments.
    /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
    #[serde(default)]
    pub middle_proxy_nat_ip: Option<IpAddr>,

    #[serde(default)]
    pub log_level: LogLevel,
}

@@ -135,6 +167,7 @@ impl Default for GeneralConfig

            use_middle_proxy: false,
            ad_tag: None,
            proxy_secret_path: None,
            middle_proxy_nat_ip: None,
            log_level: LogLevel::Normal,
        }
    }
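The new middle_proxy_nat_ip field above is optional and defaults to None, so NAT-free deployments need no change. Below is a minimal stand-alone sketch of how such an Option<IpAddr> override deserializes from TOML; it uses serde and toml the way ProxyConfig::load does, but the NatDemo struct and the sample address are made up for illustration:

use serde::Deserialize;
use std::net::IpAddr;

#[derive(Deserialize)]
struct NatDemo {
    // Mirrors GeneralConfig::middle_proxy_nat_ip: an absent key means "no override".
    #[serde(default)]
    middle_proxy_nat_ip: Option<IpAddr>,
}

fn main() {
    // Behind NAT the socket sees a private address, but ME key derivation and
    // the RPC_PROXY_REQ "our_addr" field should carry the public one.
    let with_override: NatDemo =
        toml::from_str(r#"middle_proxy_nat_ip = "203.0.113.7""#).unwrap();
    assert_eq!(
        with_override.middle_proxy_nat_ip,
        Some("203.0.113.7".parse().unwrap())
    );

    // Omitting the key yields None, so the locally detected address is used as is.
    let without: NatDemo = toml::from_str("").unwrap();
    assert!(without.middle_proxy_nat_ip.is_none());
}

Keeping the field as Option<IpAddr> lets the proxy distinguish "no override configured" from any concrete address.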
@@ -147,16 +180,16 @@ pub struct ServerConfig

    #[serde(default = "default_listen_addr")]
    pub listen_addr_ipv4: String,

    #[serde(default)]
    pub listen_addr_ipv6: Option<String>,

    #[serde(default)]
    pub listen_unix_sock: Option<String>,

    #[serde(default)]
    pub metrics_port: Option<u16>,

    #[serde(default = "default_metrics_whitelist")]
    pub metrics_whitelist: Vec<IpAddr>,

@@ -182,13 +215,13 @@ impl Default for ServerConfig

pub struct TimeoutsConfig {
    #[serde(default = "default_handshake_timeout")]
    pub client_handshake: u64,
    #[serde(default = "default_connect_timeout")]
    pub tg_connect: u64,
    #[serde(default = "default_keepalive")]
    pub client_keepalive: u64,
    #[serde(default = "default_ack_timeout")]
    pub client_ack: u64,
}

@@ -208,13 +241,13 @@ impl Default for TimeoutsConfig

pub struct AntiCensorshipConfig {
    #[serde(default = "default_tls_domain")]
    pub tls_domain: String,
    #[serde(default = "default_true")]
    pub mask: bool,
    #[serde(default)]
    pub mask_host: Option<String>,
    #[serde(default = "default_mask_port")]
    pub mask_port: u16,

@@ -245,19 +278,19 @@ pub struct AccessConfig

    #[serde(default)]
    pub user_max_tcp_conns: HashMap<String, usize>,
    #[serde(default)]
    pub user_expirations: HashMap<String, DateTime<Utc>>,
    #[serde(default)]
    pub user_data_quota: HashMap<String, u64>,
    #[serde(default = "default_replay_check_len")]
    pub replay_check_len: usize,
    #[serde(default = "default_replay_window_secs")]
    pub replay_window_secs: u64,
    #[serde(default)]
    pub ignore_time_skew: bool,
}

@@ -265,7 +298,10 @@ pub struct AccessConfig

impl Default for AccessConfig {
    fn default() -> Self {
        let mut users = HashMap::new();
        users.insert(
            "default".to_string(),
            "00000000000000000000000000000000".to_string(),
        );
        Self {
            users,
            user_max_tcp_conns: HashMap::new(),

@@ -365,12 +401,12 @@ pub struct ProxyConfig

impl ProxyConfig {
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content =
            std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?;
        let mut config: ProxyConfig =
            toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?;

        // Validate secrets
        for (user, secret) in &config.access.users {
            if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {

@@ -380,33 +416,34 @@ impl ProxyConfig

                });
            }
        }

        // Validate tls_domain
        if config.censorship.tls_domain.is_empty() {
            return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
        }

        // Validate mask_unix_sock
        if let Some(ref sock_path) = config.censorship.mask_unix_sock {
            if sock_path.is_empty() {
                return Err(ProxyError::Config(
                    "mask_unix_sock cannot be empty".to_string(),
                ));
            }

            #[cfg(unix)]
            if sock_path.len() > 107 {
                return Err(ProxyError::Config(format!(
                    "mask_unix_sock path too long: {} bytes (max 107)",
                    sock_path.len()
                )));
            }

            #[cfg(not(unix))]
            return Err(ProxyError::Config(
                "mask_unix_sock is only supported on Unix platforms".to_string(),
            ));

            if config.censorship.mask_host.is_some() {
                return Err(ProxyError::Config(
                    "mask_unix_sock and mask_host are mutually exclusive".to_string(),
                ));
            }
        }

@@ -415,11 +452,11 @@ impl ProxyConfig

        if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() {
            config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
        }

        // Random fake_cert_len
        use rand::Rng;
        config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);

        // Migration: Populate listeners if empty
        if config.server.listeners.is_empty() {
            if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {

@@ -429,7 +466,7 @@ impl ProxyConfig

                });
            }
            if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
                if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
                    config.server.listeners.push(ListenerConfig {
                        ip: ipv6,
                        announce_ip: None,

@@ -440,31 +477,32 @@ impl ProxyConfig

        // Migration: Populate upstreams if empty (Default Direct)
        if config.upstreams.is_empty() {
            config.upstreams.push(UpstreamConfig {
                upstream_type: UpstreamType::Direct { interface: None },
                weight: 1,
                enabled: true,
            });
        }

        Ok(config)
    }

    pub fn validate(&self) -> Result<()> {
        if self.access.users.is_empty() {
            return Err(ProxyError::Config("No users configured".to_string()));
        }
        if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
            return Err(ProxyError::Config("No modes enabled".to_string()));
        }
        if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
            return Err(ProxyError::Config(format!(
                "Invalid tls_domain: '{}'. Must be a valid domain name",
                self.censorship.tls_domain
            )));
        }
        Ok(())
    }
}

View File

@@ -6,8 +6,8 @@ use std::time::Duration;

use tokio::net::TcpListener;
use tokio::signal;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};

mod cli;
mod config;

@@ -20,14 +20,14 @@ mod stream;

mod transport;
mod util;

use crate::config::{LogLevel, ProxyConfig};
use crate::crypto::SecureRandom;
use crate::proxy::ClientHandler;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool;
use crate::transport::middle_proxy::MePool;
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
use crate::util::ip::detect_ip;

fn parse_cli() -> (String, bool, Option<String>) {
    let mut config_path = "config.toml".to_string();

@@ -48,10 +48,14 @@ fn parse_cli() -> (String, bool, Option<String>) {

    let mut i = 0;
    while i < args.len() {
        match args[i].as_str() {
            "--silent" | "-s" => {
                silent = true;
            }
            "--log-level" => {
                i += 1;
                if i < args.len() {
                    log_level = Some(args[i].clone());
                }
            }
            s if s.starts_with("--log-level=") => {
                log_level = Some(s.trim_start_matches("--log-level=").to_string());

@@ -65,17 +69,27 @@ fn parse_cli() -> (String, bool, Option<String>) {

                eprintln!("  --help, -h           Show this help");
                eprintln!();
                eprintln!("Setup (fire-and-forget):");
                eprintln!(
                    "  --init               Generate config, install systemd service, start"
                );
                eprintln!("  --port <PORT>        Listen port (default: 443)");
                eprintln!(
                    "  --domain <DOMAIN>    TLS domain for masking (default: www.google.com)"
                );
                eprintln!(
                    "  --secret <HEX>       32-char hex secret (auto-generated if omitted)"
                );
                eprintln!("  --user <NAME>        Username (default: user)");
                eprintln!("  --config-dir <DIR>   Config directory (default: /etc/telemt)");
                eprintln!("  --no-start           Don't start the service after install");
                std::process::exit(0);
            }
            s if !s.starts_with('-') => {
                config_path = s.to_string();
            }
            other => {
                eprintln!("Unknown option: {}", other);
            }
        }
        i += 1;
    }

@@ -124,21 +138,30 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

    info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
    info!("Log level: {}", effective_log_level);
    info!(
        "Modes: classic={} secure={} tls={}",
        config.general.modes.classic, config.general.modes.secure, config.general.modes.tls
    );
    info!("TLS domain: {}", config.censorship.tls_domain);

    if let Some(ref sock) = config.censorship.mask_unix_sock {
        info!("Mask: {} -> unix:{}", config.censorship.mask, sock);
        if !std::path::Path::new(sock).exists() {
            warn!(
                "Unix socket '{}' does not exist yet. Masking will fail until it appears.",
                sock
            );
        }
    } else {
        info!(
            "Mask: {} -> {}:{}",
            config.censorship.mask,
            config
                .censorship
                .mask_host
                .as_deref()
                .unwrap_or(&config.censorship.tls_domain),
            config.censorship.mask_port
        );
    }

    if config.censorship.tls_domain == "www.google.com" {

@@ -166,69 +189,78 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

    // Middle Proxy initialization (if enabled)
    // =====================================================================
    let me_pool: Option<Arc<MePool>> = if use_middle_proxy {
        info!("=== Middle Proxy Mode ===");

        // ad_tag (proxy_tag) for advertising
        let proxy_tag = config.general.ad_tag.as_ref().map(|tag| {
            hex::decode(tag).unwrap_or_else(|_| {
                warn!("Invalid ad_tag hex, middle proxy ad_tag will be empty");
                Vec::new()
            })
        });

        // =============================================================
        // CRITICAL: Download Telegram proxy-secret (NOT user secret!)
        //
        // C MTProxy uses TWO separate secrets:
        //   -S flag    = 16-byte user secret for client obfuscation
        //   --aes-pwd  = 32-512 byte binary file for ME RPC auth
        //
        // proxy-secret is from: https://core.telegram.org/getProxySecret
        // =============================================================
        let proxy_secret_path = config.general.proxy_secret_path.as_deref();

        match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).await {
            Ok(proxy_secret) => {
                info!(
                    secret_len = proxy_secret.len(),
                    key_sig = format_args!(
                        "0x{:08x}",
                        if proxy_secret.len() >= 4 {
                            u32::from_le_bytes([
                                proxy_secret[0],
                                proxy_secret[1],
                                proxy_secret[2],
                                proxy_secret[3],
                            ])
                        } else {
                            0
                        }
                    ),
                    "Proxy-secret loaded"
                );

                let pool = MePool::new(proxy_tag, proxy_secret, config.general.middle_proxy_nat_ip);

                match pool.init(2, &rng).await {
                    Ok(()) => {
                        info!("Middle-End pool initialized successfully");

                        // Phase 4: Start health monitor
                        let pool_clone = pool.clone();
                        let rng_clone = rng.clone();
                        tokio::spawn(async move {
                            crate::transport::middle_proxy::me_health_monitor(
                                pool_clone, rng_clone, 2,
                            )
                            .await;
                        });

                        Some(pool)
                    }
                    Err(e) => {
                        error!(error = %e, "Failed to initialize ME pool. Falling back to direct mode.");
                        None
                    }
                }
            }
            Err(e) => {
                error!(error = %e, "Failed to fetch proxy-secret. Falling back to direct mode.");
                None
            }
        }
    } else {
        None
    };

    if me_pool.is_some() {
        info!("Transport: Middle Proxy (supports all DCs including CDN)");
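The key_sig printed in the "Proxy-secret loaded" line above is simply the first four bytes of the downloaded proxy-secret read as a little-endian u32, or 0 if the file is shorter than four bytes. A self-contained sketch of that computation; the key_sig helper name and the dummy buffer are invented for illustration, the real code does this inline inside the info! call:

/// First four little-endian bytes of the proxy-secret, as logged above.
fn key_sig(proxy_secret: &[u8]) -> u32 {
    if proxy_secret.len() >= 4 {
        u32::from_le_bytes([
            proxy_secret[0],
            proxy_secret[1],
            proxy_secret[2],
            proxy_secret[3],
        ])
    } else {
        0
    }
}

fn main() {
    // The comment block above describes the real file as 32-512 bytes of
    // binary data; this short dummy buffer only illustrates the extraction.
    let secret = [0xDEu8, 0xAD, 0xBE, 0xEF, 0x00, 0x01, 0x02, 0x03];
    assert_eq!(key_sig(&secret), 0xEFBE_ADDE);
    println!("key_sig = 0x{:08x}", key_sig(&secret));
}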
@@ -251,8 +283,14 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

            info!("  IPv4 in use and IPv6 is fallback");
        }
    } else {
        let v6_works = upstream_result
            .v6_results
            .iter()
            .any(|r| r.rtt_ms.is_some());
        let v4_works = upstream_result
            .v4_results
            .iter()
            .any(|r| r.rtt_ms.is_some());
        if v6_works && !v4_works {
            info!("  IPv6 only (IPv4 unavailable)");
        } else if v4_works && !v6_works {

@@ -290,11 +328,17 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

            Some(rtt) => {
                // Align: IPv4 addresses are shorter, use more tabs
                // 149.154.175.50:443 = ~18 chars
                info!(
                    "  DC{} [IPv4] {}:\t\t\t\t{:.0} ms",
                    dc.dc_idx, addr_str, rtt
                );
            }
            None => {
                let err = dc.error.as_deref().unwrap_or("fail");
                info!(
                    "  DC{} [IPv4] {}:\t\t\t\tFAIL ({})",
                    dc.dc_idx, addr_str, err
                );
            }
        }
    }

@@ -305,13 +349,20 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

    // Background tasks
    let um_clone = upstream_manager.clone();
    tokio::spawn(async move {
        um_clone.run_health_checks(prefer_ipv6).await;
    });

    let rc_clone = replay_checker.clone();
    tokio::spawn(async move {
        rc_clone.run_periodic_cleanup().await;
    });

    let detected_ip = detect_ip().await;
    debug!(
        "Detected IPs: v4={:?} v6={:?}",
        detected_ip.ipv4, detected_ip.ipv6
    );

    let mut listeners = Vec::new();

@@ -345,17 +396,23 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

                if let Some(secret) = config.access.users.get(user_name) {
                    info!("User: {}", user_name);
                    if config.general.modes.classic {
                        info!(
                            "  Classic: tg://proxy?server={}&port={}&secret={}",
                            public_ip, config.server.port, secret
                        );
                    }
                    if config.general.modes.secure {
                        info!(
                            "  DD: tg://proxy?server={}&port={}&secret=dd{}",
                            public_ip, config.server.port, secret
                        );
                    }
                    if config.general.modes.tls {
                        let domain_hex = hex::encode(&config.censorship.tls_domain);
                        info!(
                            "  EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
                            public_ip, config.server.port, secret, domain_hex
                        );
                    }
                } else {
                    warn!("User '{}' in show_link not found", user_name);
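The three share links printed above differ only in the secret prefix: classic mode uses the bare 32-char hex secret, the dd prefix marks padded (secure) mode, and the ee prefix plus the hex-encoded SNI domain selects fake-TLS. A small sketch that assembles the same strings; the share_links helper and its sample inputs are illustrative only and not part of the project:

// Assumes the hex crate, which the proxy already uses for domain encoding.
fn share_links(ip: &str, port: u16, secret: &str, tls_domain: &str) -> [String; 3] {
    let classic = format!("tg://proxy?server={ip}&port={port}&secret={secret}");
    let dd = format!("tg://proxy?server={ip}&port={port}&secret=dd{secret}");
    // EE-TLS appends the hex-encoded SNI domain after the user secret.
    let domain_hex = hex::encode(tls_domain);
    let ee = format!("tg://proxy?server={ip}&port={port}&secret=ee{secret}{domain_hex}");
    [classic, dd, ee]
}

fn main() {
    let [classic, dd, ee] = share_links(
        "203.0.113.7",
        443,
        "00000000000000000000000000000000",
        "www.google.com",
    );
    println!("{classic}\n{dd}\n{ee}");
}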
@@ -365,7 +422,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

                }

                listeners.push(listener);
            }
            Err(e) => {
                error!("Failed to bind to {}: {}", addr, e);
            }

@@ -383,7 +440,9 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

    } else {
        EnvFilter::new(effective_log_level.to_filter_str())
    };
    filter_handle
        .reload(runtime_filter)
        .expect("Failed to switch log filter");

    for listener in listeners {
        let config = config.clone();

@@ -408,10 +467,19 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

            tokio::spawn(async move {
                if let Err(e) = ClientHandler::new(
                    stream,
                    peer_addr,
                    config,
                    stats,
                    upstream_manager,
                    replay_checker,
                    buffer_pool,
                    rng,
                    me_pool,
                )
                .run()
                .await
                {
                    debug!(peer = %peer_addr, error = %e, "Connection error");
                }
            });

@@ -431,4 +499,4 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

    }

    Ok(())
}

View File

@@ -1,661 +1,354 @@
//! Client Handler

use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, warn};

use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::{HandshakeResult, ProxyError, Result};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::MePool;
use crate::transport::{UpstreamManager, configure_client_socket};

use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy;

pub struct ClientHandler;

pub struct RunningClientHandler {
    stream: TcpStream,
    peer: SocketAddr,
    config: Arc<ProxyConfig>,
    stats: Arc<Stats>,
    replay_checker: Arc<ReplayChecker>,
    upstream_manager: Arc<UpstreamManager>,
    buffer_pool: Arc<BufferPool>,
    rng: Arc<SecureRandom>,
    me_pool: Option<Arc<MePool>>,
}

impl ClientHandler {
    pub fn new(
        stream: TcpStream,
        peer: SocketAddr,
        config: Arc<ProxyConfig>,
        stats: Arc<Stats>,
        upstream_manager: Arc<UpstreamManager>,
        replay_checker: Arc<ReplayChecker>,
        buffer_pool: Arc<BufferPool>,
        rng: Arc<SecureRandom>,
        me_pool: Option<Arc<MePool>>,
    ) -> RunningClientHandler {
        RunningClientHandler {
            stream,
            peer,
            config,
            stats,
            replay_checker,
            upstream_manager,
            buffer_pool,
            rng,
            me_pool,
        }
    }
}

impl RunningClientHandler {
    pub async fn run(mut self) -> Result<()> {
        self.stats.increment_connects_all();

        let peer = self.peer;
        debug!(peer = %peer, "New connection");

        if let Err(e) = configure_client_socket(
            &self.stream,
            self.config.timeouts.client_keepalive,
            self.config.timeouts.client_ack,
        ) {
            debug!(peer = %peer, error = %e, "Failed to configure client socket");
        }

        let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
        let stats = self.stats.clone();

        let result = timeout(handshake_timeout, self.do_handshake()).await;

        match result {
            Ok(Ok(())) => {
                debug!(peer = %peer, "Connection handled successfully");
                Ok(())
            }
            Ok(Err(e)) => {
                debug!(peer = %peer, error = %e, "Handshake failed");
                Err(e)
            }
            Err(_) => {
                stats.increment_handshake_timeouts();
                debug!(peer = %peer, "Handshake timeout");
                Err(ProxyError::TgHandshakeTimeout)
            }
        }
    }

    async fn do_handshake(mut self) -> Result<()> {
        let mut first_bytes = [0u8; 5];
        self.stream.read_exact(&mut first_bytes).await?;

        let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
        let peer = self.peer;

        debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");

        if is_tls {
            self.handle_tls_client(first_bytes).await
        } else {
            self.handle_direct_client(first_bytes).await
        }
    }

    async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
        let peer = self.peer;
        let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;

        debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");

        if tls_len < 512 {
            debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
            self.stats.increment_connects_bad();
            let (reader, writer) = self.stream.into_split();
            handle_bad_client(reader, writer, &first_bytes, &self.config).await;
            return Ok(());
        }

        let mut handshake = vec![0u8; 5 + tls_len];
        handshake[..5].copy_from_slice(&first_bytes);
        self.stream.read_exact(&mut handshake[5..]).await?;

        let config = self.config.clone();
        let replay_checker = self.replay_checker.clone();
        let stats = self.stats.clone();
        let buffer_pool = self.buffer_pool.clone();

        let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
        let (read_half, write_half) = self.stream.into_split();

        let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
            &handshake,
            read_half,
            write_half,
            peer,
            &config,
            &replay_checker,
            &self.rng,
        )
        .await
        {
            HandshakeResult::Success(result) => result,
            HandshakeResult::BadClient { reader, writer } => {
                stats.increment_connects_bad();
                handle_bad_client(reader, writer, &handshake, &config).await;
                return Ok(());
            }
            HandshakeResult::Error(e) => return Err(e),
        };

        debug!(peer = %peer, "Reading MTProto handshake through TLS");
        let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
        let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..]
            .try_into()
            .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;

        let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
            &mtproto_handshake,
            tls_reader,
            tls_writer,
            peer,
            &config,
            &replay_checker,
            true,
        )
        .await
        {
            HandshakeResult::Success(result) => result,
            HandshakeResult::BadClient {
                reader: _,
                writer: _,
            } => {
                stats.increment_connects_bad();
                debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
                return Ok(());
            }
            HandshakeResult::Error(e) => return Err(e),
        };

        Self::handle_authenticated_static(
            crypto_reader,
            crypto_writer,
            success,
            self.upstream_manager,
            self.stats,
            self.config,
            buffer_pool,
            self.rng,
            self.me_pool,
            local_addr,
        )
        .await
    }

    async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
        let peer = self.peer;

        if !self.config.general.modes.classic && !self.config.general.modes.secure {
            debug!(peer = %peer, "Non-TLS modes disabled");
            self.stats.increment_connects_bad();
            let (reader, writer) = self.stream.into_split();
            handle_bad_client(reader, writer, &first_bytes, &self.config).await;
            return Ok(());
        }

        let mut handshake = [0u8; HANDSHAKE_LEN];
        handshake[..5].copy_from_slice(&first_bytes);
        self.stream.read_exact(&mut handshake[5..]).await?;

        let config = self.config.clone();
        let replay_checker = self.replay_checker.clone();
        let stats = self.stats.clone();
        let buffer_pool = self.buffer_pool.clone();

        let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
        let (read_half, write_half) = self.stream.into_split();

        let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
            &handshake,
            read_half,
            write_half,
            peer,
            &config,
            &replay_checker,
            false,
        )
        .await
        {
            HandshakeResult::Success(result) => result,
            HandshakeResult::BadClient { reader, writer } => {
                stats.increment_connects_bad();
                handle_bad_client(reader, writer, &handshake, &config).await;
                return Ok(());
            }
            HandshakeResult::Error(e) => return Err(e),
        };

        Self::handle_authenticated_static(
            crypto_reader,
            crypto_writer,
            success,
            self.upstream_manager,
            self.stats,
            self.config,
            buffer_pool,
            self.rng,
            self.me_pool,
            local_addr,
        )
        .await
    }

    /// Main dispatch after successful handshake.
    /// Two modes:
    /// - Direct: TCP relay to TG DC (existing behavior)
    /// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
    async fn handle_authenticated_static<R, W>(
        client_reader: CryptoReader<R>,
        client_writer: CryptoWriter<W>,
        success: HandshakeSuccess,
        upstream_manager: Arc<UpstreamManager>,
        stats: Arc<Stats>,
        config: Arc<ProxyConfig>,
        buffer_pool: Arc<BufferPool>,
        rng: Arc<SecureRandom>,
        me_pool: Option<Arc<MePool>>,
        local_addr: SocketAddr,
    ) -> Result<()>
    where
        R: AsyncRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        let user = &success.user;

        if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
            warn!(user = %user, error = %e, "User limit exceeded");
            return Err(e);
        }

        // Decide: middle proxy or direct
        if config.general.use_middle_proxy {
            if let Some(ref pool) = me_pool {
                return handle_via_middle_proxy(
                    client_reader,
                    client_writer,
                    success,
                    pool.clone(),
                    stats,
                    config,
                    buffer_pool,
                    local_addr,
                )
                .await;
            }
            warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
        }

        // Direct mode (original behavior)
        handle_via_direct(
            client_reader,
            client_writer,
            success,
            upstream_manager,
            stats,
            config,
            buffer_pool,
            rng,
        )
        .await
    }

    fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
        if let Some(expiration) = config.access.user_expirations.get(user) {
            if chrono::Utc::now() > *expiration {
                return Err(ProxyError::UserExpired {
                    user: user.to_string(),
                });
            }
        }
        if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
            if stats.get_user_curr_connects(user) >= *limit as u64 {
                return Err(ProxyError::ConnectionLimitExceeded {
                    user: user.to_string(),
                });
            }
        }
        if let Some(quota) = config.access.user_data_quota.get(user) {
            if stats.get_user_total_octets(user) >= *quota {
                return Err(ProxyError::DataQuotaExceeded {
                    user: user.to_string(),
                });
            }
        }
        Ok(())
    }
}

163
src/proxy/direct_relay.rs Normal file
View File

@@ -0,0 +1,163 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info, warn};
use crate::config::ProxyConfig;
use crate::crypto::SecureRandom;
use crate::error::Result;
use crate::protocol::constants::*;
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
use crate::proxy::relay::relay_bidirectional;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager;
pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = &success.user;
let dc_addr = get_dc_addr_static(success.dc_idx, &config)?;
info!(
user = %user,
peer = %success.peer,
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
mode = "direct",
"Connecting to Telegram DC"
);
let tg_stream = upstream_manager
.connect(dc_addr, Some(success.dc_idx))
.await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
let (tg_reader, tg_writer) =
do_tg_handshake_static(tg_stream, &success, &config, rng.as_ref()).await?;
debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&stats),
buffer_pool,
)
.await;
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"),
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
}
relay_result
}
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
};
let num_dcs = datacenters.len();
let dc_key = dc_idx.to_string();
if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
return Ok(addr);
}
Err(_) => {
warn!(dc_idx = dc_idx, addr_str = %addr_str,
"Invalid DC override address in config, ignoring");
}
}
}
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"DC index outside known range, falling back to default DC"
);
Ok(SocketAddr::new(
datacenters[fallback_idx],
TG_DATACENTER_PORT,
))
}
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(
CryptoReader<tokio::net::tcp::OwnedReadHalf>,
CryptoWriter<tokio::net::tcp::OwnedWriteHalf>,
)> {
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
success.dc_idx,
&success.dec_key,
success.dec_iv,
rng,
config.general.fast_mode,
);
let (encrypted_nonce, tg_encryptor, tg_decryptor) = encrypt_tg_nonce_with_ciphers(&nonce);
debug!(
peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]),
"Sending nonce to Telegram"
);
stream.write_all(&encrypted_nonce).await?;
stream.flush().await?;
let (read_half, write_half) = stream.into_split();
Ok((
CryptoReader::new(read_half, tg_decryptor),
CryptoWriter::new(write_half, tg_encryptor),
))
}

247
src/proxy/middle_relay.rs Normal file
View File

@@ -0,0 +1,247 @@
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, info, trace};
use crate::config::ProxyConfig;
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
mut crypto_writer: CryptoWriter<W>,
success: HandshakeSuccess,
me_pool: Arc<MePool>,
stats: Arc<Stats>,
_config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let peer = success.peer;
let proto_tag = success.proto_tag;
info!(
user = %user,
peer = %peer,
dc = success.dc_idx,
proto = ?proto_tag,
mode = "middle_proxy",
"Routing via Middle-End"
);
let (conn_id, mut me_rx) = me_pool.registry().register().await;
stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
let proto_flags = proto_flags_for_tag(proto_tag, me_pool.has_proxy_tag());
debug!(
user = %user,
conn_id,
proto_flags = format_args!("0x{:08x}", proto_flags),
"ME relay started"
);
let translated_local_addr = me_pool.translate_our_addr(local_addr);
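// Relay loop: client frames are wrapped in RPC_PROXY_REQ toward the ME; ME responses arrive on the
// per-connection channel as data payloads, quick-acks, or a close notification.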
let result: Result<()> = loop {
tokio::select! {
client_frame = read_client_payload(&mut crypto_reader, proto_tag) => {
match client_frame {
Ok(Some(payload)) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64);
me_pool.send_proxy_req(conn_id, peer, translated_local_addr, &payload, proto_flags).await?;
}
Ok(None) => {
debug!(conn_id, "Client EOF");
let _ = me_pool.send_close(conn_id).await;
break Ok(());
}
Err(e) => break Err(e),
}
}
me_msg = me_rx.recv() => {
match me_msg {
Some(MeResponse::Data { flags, data }) => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats.add_user_octets_to(&user, data.len() as u64);
write_client_payload(&mut crypto_writer, proto_tag, flags, &data).await?;
}
Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack");
write_client_ack(&mut crypto_writer, proto_tag, confirm).await?;
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close");
break Ok(());
}
None => {
debug!(conn_id, "ME channel closed");
break Err(ProxyError::Proxy("ME connection lost".into()));
}
}
}
}
};
debug!(user = %user, conn_id, "ME relay cleanup");
me_pool.registry().unregister(conn_id).await;
stats.decrement_user_curr_connects(&user);
result
}
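// Reads one MTProto transport frame from the client. Abridged frames carry their length in 4-byte
// words (a single header byte, or 0x7f plus a 3-byte length for large frames); Intermediate/Secure
// frames use a 4-byte LE length whose top bit is masked off. Returns Ok(None) if the client closed
// the connection before sending a new frame header.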
async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag,
) -> Result<Option<Vec<u8>>>
where
R: AsyncRead + Unpin + Send + 'static,
{
let len = match proto_tag {
ProtoTag::Abridged => {
let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize
}
};
if len > 16 * 1024 * 1024 {
return Err(ProxyError::Proxy(format!("Frame too large: {len}")));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
Ok(Some(payload))
}
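// Writes one ME payload back to the client in its negotiated framing, setting the transport-level
// quickack bit (high bit of the Abridged header byte / Intermediate length word) when the ME
// response carried RPC_FLAG_QUICKACK.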
async fn write_client_payload<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
flags: u32,
data: &[u8],
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let quickack = (flags & RPC_FLAG_QUICKACK) != 0;
match proto_tag {
ProtoTag::Abridged => {
if data.len() % 4 != 0 {
return Err(ProxyError::Proxy(format!(
"Abridged payload must be 4-byte aligned, got {}",
data.len()
)));
}
let len_words = data.len() / 4;
if len_words < 0x7f {
let mut first = len_words as u8;
if quickack {
first |= 0x80;
}
client_writer
.write_all(&[first])
.await
.map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) {
let mut first = 0x7fu8;
if quickack {
first |= 0x80;
}
let lw = (len_words as u32).to_le_bytes();
client_writer
.write_all(&[first, lw[0], lw[1], lw[2]])
.await
.map_err(ProxyError::Io)?;
} else {
return Err(ProxyError::Proxy(format!(
"Abridged frame too large: {}",
data.len()
)));
}
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len = data.len() as u32;
if quickack {
len |= 0x8000_0000;
}
client_writer
.write_all(&len.to_le_bytes())
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
}
client_writer.flush().await.map_err(ProxyError::Io)
}
async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
confirm: u32,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
let bytes = if proto_tag == ProtoTag::Abridged {
confirm.to_be_bytes()
} else {
confirm.to_le_bytes()
};
client_writer
.write_all(&bytes)
.await
.map_err(ProxyError::Io)?;
client_writer.flush().await.map_err(ProxyError::Io)
}

178
src/transport/middle_proxy/codec.rs Normal file
View File

@@ -0,0 +1,178 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
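// RPC transport frame layout on the ME link:
// | total_len (u32 LE, counts itself and the CRC) | seq_no (i32 LE) | payload | crc32 of all preceding bytes (u32 LE) |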
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload);
let c = crc32(&frame);
frame.extend_from_slice(&c.to_le_bytes());
frame
}
pub(crate) async fn read_rpc_frame_plaintext(
rd: &mut (impl AsyncReadExt + Unpin),
) -> Result<(i32, Vec<u8>)> {
let mut len_buf = [0u8; 4];
rd.read_exact(&mut len_buf).await.map_err(ProxyError::Io)?;
let total_len = u32::from_le_bytes(len_buf) as usize;
if !(12..=(1 << 24)).contains(&total_len) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad RPC frame length: {total_len}"
)));
}
let mut rest = vec![0u8; total_len - 4];
rd.read_exact(&mut rest).await.map_err(ProxyError::Io)?;
let mut full = Vec::with_capacity(total_len);
full.extend_from_slice(&len_buf);
full.extend_from_slice(&rest);
let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]);
if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
)));
}
let seq_no = i32::from_le_bytes(full[4..8].try_into().unwrap());
let payload = full[8..crc_offset].to_vec();
Ok((seq_no, payload))
}
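// RPC_NONCE payload (32 bytes): magic (4) | key_selector = first 4 bytes of the proxy-secret (4) |
// RPC_CRYPTO_AES (4) | crypto_ts (4) | 16-byte client nonce.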
pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8; 16]) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_NONCE_U32.to_le_bytes());
p[4..8].copy_from_slice(&key_selector.to_le_bytes());
p[8..12].copy_from_slice(&RPC_CRYPTO_AES_U32.to_le_bytes());
p[12..16].copy_from_slice(&crypto_ts.to_le_bytes());
p[16..32].copy_from_slice(nonce);
p
}
pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> {
if d.len() < 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Nonce payload too short: {} bytes",
d.len()
)));
}
let t = u32::from_le_bytes(d[0..4].try_into().unwrap());
if t != RPC_NONCE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected RPC_NONCE 0x{RPC_NONCE_U32:08x}, got 0x{t:08x}"
)));
}
let schema = u32::from_le_bytes(d[8..12].try_into().unwrap());
let ts = u32::from_le_bytes(d[12..16].try_into().unwrap());
let mut nonce = [0u8; 16];
nonce.copy_from_slice(&d[16..32]);
Ok((schema, ts, nonce))
}
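// RPC_HANDSHAKE payload (32 bytes): magic (4) | flags (4, zero) | our PID block: ip(4) port(2) pid(2) utime(4) |
// peer PID block: ip(4) port(2) | zero padding.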
pub(crate) fn build_handshake_payload(
our_ip: [u8; 4],
our_port: u16,
peer_ip: [u8; 4],
peer_port: u16,
) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes.
p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16;
p[14..16].copy_from_slice(&pid.to_le_bytes());
let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes());
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
}
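// Pads the plaintext to a 16-byte AES-CBC block boundary with the repeating TL pattern 0x04 0x00 0x00 0x00,
// encrypts in place, and returns the ciphertext plus its last block (the IV for chaining the next frame).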
pub(crate) fn cbc_encrypt_padded(
key: &[u8; 32],
iv: &[u8; 16],
plaintext: &[u8],
) -> Result<(Vec<u8>, [u8; 16])> {
let pad = (16 - (plaintext.len() % 16)) % 16;
let mut buf = plaintext.to_vec();
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(*key, *iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("CBC encrypt: {e}")))?;
let mut new_iv = [0u8; 16];
if buf.len() >= 16 {
new_iv.copy_from_slice(&buf[buf.len() - 16..]);
}
Ok((buf, new_iv))
}
pub(crate) fn cbc_decrypt_inplace(
key: &[u8; 32],
iv: &[u8; 16],
data: &mut [u8],
) -> Result<[u8; 16]> {
let mut new_iv = [0u8; 16];
if data.len() >= 16 {
new_iv.copy_from_slice(&data[data.len() - 16..]);
}
AesCbc::new(*key, *iv)
.decrypt_in_place(data)
.map_err(|e| ProxyError::Crypto(format!("CBC decrypt: {e}")))?;
Ok(new_iv)
}
pub(crate) struct RpcWriter {
pub(crate) writer: tokio::io::WriteHalf<tokio::net::TcpStream>,
pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32,
}
impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload);
self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16;
let mut buf = frame;
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {
buf.push(pad_pattern[i % 4]);
}
let cipher = AesCbc::new(self.key, self.iv);
cipher
.encrypt_in_place(&mut buf)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if buf.len() >= 16 {
self.iv.copy_from_slice(&buf[buf.len() - 16..]);
}
self.writer.write_all(&buf).await.map_err(ProxyError::Io)
}
}

38
src/transport/middle_proxy/health.rs Normal file
View File

@@ -0,0 +1,38 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::crypto::SecureRandom;
use crate::protocol::constants::TG_MIDDLE_PROXIES_FLAT_V4;
use super::MePool;
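// Background task: every 30 seconds, if the ME pool has fallen below min_connections, walk the flat
// middle-proxy address list and open new RPC connections until the minimum is restored.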
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, min_connections: usize) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
let current = pool.connection_count();
if current < min_connections {
warn!(
current,
min = min_connections,
"ME pool below minimum, reconnecting..."
);
let addrs = TG_MIDDLE_PROXIES_FLAT_V4.clone();
for &(ip, port) in addrs.iter() {
let needed = min_connections.saturating_sub(pool.connection_count());
if needed == 0 {
break;
}
for _ in 0..needed {
let addr = SocketAddr::new(ip, port);
match pool.connect_one(addr, &rng).await {
Ok(()) => info!(%addr, "ME reconnected"),
Err(e) => debug!(%addr, error = %e, "ME reconnect failed"),
}
}
}
}
}
}

24
src/transport/middle_proxy/mod.rs Normal file
View File

@@ -0,0 +1,24 @@
//! Middle Proxy RPC transport.
mod codec;
mod health;
mod pool;
mod reader;
mod registry;
mod secret;
mod wire;
use bytes::Bytes;
pub use health::me_health_monitor;
pub use pool::MePool;
pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret;
pub use wire::proto_flags_for_tag;
#[derive(Debug)]
pub enum MeResponse {
Data { flags: u32, data: Bytes },
Ack(u32),
Close,
}

431
src/transport/middle_proxy/pool.rs Normal file
View File

@@ -0,0 +1,431 @@
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, RwLock};
use tokio::time::{Instant, timeout};
use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, derive_middleproxy_keys};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::ConnRegistry;
use super::codec::{
RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext,
};
use super::reader::reader_loop;
use super::wire::{IpMaterial, build_proxy_req_payload, extract_ip_material};
pub struct MePool {
registry: Arc<ConnRegistry>,
writers: Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>>,
rr: AtomicU64,
proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
nat_ip: Option<IpAddr>,
pool_size: usize,
}
impl MePool {
pub fn new(
proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
nat_ip: Option<IpAddr>,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
proxy_tag,
proxy_secret,
nat_ip,
pool_size: 2,
})
}
pub fn has_proxy_tag(&self) -> bool {
self.proxy_tag.is_some()
}
pub fn translate_our_addr(&self, addr: SocketAddr) -> SocketAddr {
let ip = self.translate_ip_for_nat(addr.ip());
SocketAddr::new(ip, addr.port())
}
pub fn registry(&self) -> &Arc<ConnRegistry> {
&self.registry
}
fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr {
let Some(nat_ip) = self.nat_ip else {
return ip;
};
match (ip, nat_ip) {
(IpAddr::V4(src), IpAddr::V4(dst))
if src.is_private() || src.is_loopback() || src.is_unspecified() =>
{
IpAddr::V4(dst)
}
(IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => {
IpAddr::V6(dst)
}
(orig, _) => orig,
}
}
fn writers_arc(&self) -> Arc<RwLock<Vec<Arc<Mutex<RpcWriter>>>>> {
self.writers.clone()
}
fn key_selector(&self) -> u32 {
if self.proxy_secret.len() >= 4 {
u32::from_le_bytes([
self.proxy_secret[0],
self.proxy_secret[1],
self.proxy_secret[2],
self.proxy_secret[3],
])
} else {
0
}
}
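// Opens up to pool_size RPC connections against the known middle-proxy addresses, moving on once the
// pool holds enough writers; startup fails only if no connection at all could be established.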
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &SecureRandom) -> Result<()> {
let addrs = &*TG_MIDDLE_PROXIES_FLAT_V4;
let ks = self.key_selector();
info!(
me_servers = addrs.len(),
pool_size,
key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.len(),
"Initializing ME pool"
);
for &(ip, port) in addrs.iter() {
for i in 0..pool_size {
let addr = SocketAddr::new(ip, port);
match self.connect_one(addr, rng).await {
Ok(()) => info!(%addr, idx = i, "ME connected"),
Err(e) => warn!(%addr, idx = i, error = %e, "ME connect failed"),
}
}
if self.writers.read().await.len() >= pool_size {
break;
}
}
if self.writers.read().await.is_empty() {
return Err(ProxyError::Proxy("No ME connections".into()));
}
Ok(())
}
pub(crate) async fn connect_one(
self: &Arc<Self>,
addr: SocketAddr,
rng: &SecureRandom,
) -> Result<()> {
let secret = &self.proxy_secret;
if secret.len() < 32 {
return Err(ProxyError::Proxy(
"proxy-secret too short for ME auth".into(),
));
}
let stream = timeout(
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
TcpStream::connect(addr),
)
.await
.map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string(),
})?
.map_err(ProxyError::Io)?;
stream.set_nodelay(true).ok();
let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let local_addr_nat = self.translate_our_addr(local_addr);
let peer_addr_nat =
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
let (mut rd, mut wr) = tokio::io::split(stream);
let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap();
let crypto_ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
let ks = self.key_selector();
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let (srv_seq, srv_nonce_payload) = timeout(
Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS),
read_rpc_frame_plaintext(&mut rd),
)
.await
.map_err(|_| ProxyError::TgHandshakeTimeout)??;
if srv_seq != -2 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected seq=-2, got {srv_seq}"
)));
}
let (schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?;
if schema != RPC_CRYPTO_AES_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Unsupported crypto schema: 0x{schema:x}"
)));
}
let skew = crypto_ts.abs_diff(srv_ts);
if skew > 30 {
return Err(ProxyError::InvalidHandshake(format!(
"nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s"
)));
}
let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let client_port_bytes = local_addr_nat.port().to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat);
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
match (server_ip, client_ip) {
(IpMaterial::V4(srv), IpMaterial::V4(clt)) => {
(Some(srv), Some(clt), None, None, clt, srv)
}
(IpMaterial::V6(srv), IpMaterial::V6(clt)) => {
let zero = [0u8; 4];
(None, None, Some(clt), Some(srv), zero, zero)
}
_ => {
return Err(ProxyError::InvalidHandshake(
"mixed IPv4/IPv6 endpoints are not supported for ME key derivation"
.to_string(),
));
}
};
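// Derive direction-specific AES-CBC key/IV pairs from both nonces, the timestamp, the (NAT-translated)
// endpoint addresses and ports, and the proxy secret: the "CLIENT" material encrypts what we send to
// the ME, the "SERVER" material decrypts what it sends back.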
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (rk, ri) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let hs_payload =
build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port());
let hs_frame = build_rpc_frame(-1, &hs_payload);
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;
let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS);
let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri;
let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok {
let remaining = deadline - Instant::now();
let mut tmp = [0u8; 256];
let n = match timeout(remaining, rd.read(&mut tmp)).await {
Ok(Ok(0)) => {
return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"ME closed during handshake",
)));
}
Ok(Ok(n)) => n,
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => return Err(ProxyError::TgHandshakeTimeout),
};
enc_buf.extend_from_slice(&tmp[..n]);
let blocks = enc_buf.len() / 16 * 16;
if blocks > 0 {
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&enc_buf[..blocks]);
read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?;
dec_buf.extend_from_slice(&chunk);
let _ = enc_buf.split_to(blocks);
}
while dec_buf.len() >= 4 {
let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec_buf.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
return Err(ProxyError::InvalidHandshake(format!(
"Bad HS response frame len: {fl}"
)));
}
if dec_buf.len() < fl {
break;
}
let frame = dec_buf.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]);
if ec != ac {
return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
)));
}
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 {
i32::from_le_bytes(frame[12..16].try_into().unwrap())
} else {
-1
};
return Err(ProxyError::InvalidHandshake(format!(
"ME rejected handshake (error={err_code})"
)));
}
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
handshake_ok = true;
break;
}
}
if !handshake_ok {
return Err(ProxyError::TgHandshakeTimeout);
}
info!(%addr, "RPC handshake OK");
let rpc_w = Arc::new(Mutex::new(RpcWriter {
writer: wr,
key: wk,
iv: write_iv,
seq_no: 0,
}));
self.writers.write().await.push(rpc_w.clone());
let reg = self.registry.clone();
let w_pong = rpc_w.clone();
let w_pool = self.writers_arc();
tokio::spawn(async move {
if let Err(e) =
reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await
{
warn!(error = %e, "ME reader ended");
}
let mut ws = w_pool.write().await;
ws.retain(|w| !Arc::ptr_eq(w, &w_pong));
info!(remaining = ws.len(), "Dead ME writer removed from pool");
});
Ok(())
}
pub async fn send_proxy_req(
&self,
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proto_flags: u32,
) -> Result<()> {
let payload = build_proxy_req_payload(
conn_id,
client_addr,
our_addr,
data,
self.proxy_tag.as_deref(),
proto_flags,
);
loop {
let ws = self.writers.read().await;
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
let idx = self.rr.fetch_add(1, Ordering::Relaxed) as usize % ws.len();
let w = ws[idx].clone();
drop(ws);
match w.lock().await.send(&payload).await {
Ok(()) => return Ok(()),
Err(e) => {
warn!(error = %e, "ME write failed, removing dead conn");
let mut ws = self.writers.write().await;
ws.retain(|o| !Arc::ptr_eq(o, &w));
if ws.is_empty() {
return Err(ProxyError::Proxy("All ME connections dead".into()));
}
}
}
}
}
pub async fn send_close(&self, conn_id: u64) -> Result<()> {
let ws = self.writers.read().await;
if !ws.is_empty() {
let w = ws[0].clone();
drop(ws);
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = w.lock().await.send(&p).await {
debug!(error = %e, "ME close write failed");
let mut ws = self.writers.write().await;
ws.retain(|o| !Arc::ptr_eq(o, &w));
}
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub fn connection_count(&self) -> usize {
self.writers.try_read().map(|w| w.len()).unwrap_or(0)
}
}

141
src/transport/middle_proxy/reader.rs Normal file
View File

@@ -0,0 +1,141 @@
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use super::codec::RpcWriter;
use super::{ConnRegistry, MeResponse};
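// Shared reader for one ME connection: decrypts the AES-CBC byte stream, reassembles CRC-checked RPC
// frames, and routes RPC_PROXY_ANS / RPC_SIMPLE_ACK / RPC_CLOSE_* to the owning relay task via the
// registry; RPC_PING is answered inline with RPC_PONG.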
pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
mut div: [u8; 16],
reg: Arc<ConnRegistry>,
enc_leftover: BytesMut,
mut dec: BytesMut,
writer: Arc<Mutex<RpcWriter>>,
) -> Result<()> {
let mut raw = enc_leftover;
loop {
let mut tmp = [0u8; 16_384];
let n = rd.read(&mut tmp).await.map_err(ProxyError::Io)?;
if n == 0 {
return Ok(());
}
raw.extend_from_slice(&tmp[..n]);
let blocks = raw.len() / 16 * 16;
if blocks > 0 {
let mut new_iv = [0u8; 16];
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
let mut chunk = vec![0u8; blocks];
chunk.copy_from_slice(&raw[..blocks]);
AesCbc::new(dk, div)
.decrypt_in_place(&mut chunk)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
div = new_iv;
dec.extend_from_slice(&chunk);
let _ = raw.split_to(blocks);
}
while dec.len() >= 12 {
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
if fl == 4 {
let _ = dec.split_to(4);
continue;
}
if !(12..=(1 << 24)).contains(&fl) {
warn!(frame_len = fl, "Invalid RPC frame len");
dec.clear();
break;
}
if dec.len() < fl {
break;
}
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec {
warn!("CRC mismatch in data frame");
continue;
}
let payload = &frame[8..pe];
if payload.len() < 4 {
continue;
}
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
let body = &payload[4..];
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
let data = Bytes::copy_from_slice(&body[12..]);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed {
reg.unregister(cid).await;
send_close_conn(&writer, cid).await;
}
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_EXT from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_CONN from ME");
reg.route(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
trace!(ping_id, "RPC_PING -> RPC_PONG");
let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes());
if let Err(e) = writer.lock().await.send(&pong).await {
warn!(error = %e, "PONG send failed");
break;
}
} else {
debug!(
rpc_type = format_args!("0x{pt:08x}"),
len = body.len(),
"Unknown RPC"
);
}
}
}
}
async fn send_close_conn(writer: &Arc<Mutex<RpcWriter>>, conn_id: u64) {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if let Err(e) = writer.lock().await.send(&p).await {
debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN");
}
}

40
src/transport/middle_proxy/registry.rs Normal file
View File

@@ -0,0 +1,40 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{RwLock, mpsc};
use super::MeResponse;
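// Maps ME-side conn_ids to per-client mpsc senders so the shared reader task can route responses back
// to the relay task that owns each proxied client connection.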
pub struct ConnRegistry {
map: RwLock<HashMap<u64, mpsc::Sender<MeResponse>>>,
next_id: AtomicU64,
}
impl ConnRegistry {
pub fn new() -> Self {
Self {
map: RwLock::new(HashMap::new()),
next_id: AtomicU64::new(1),
}
}
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(256);
self.map.write().await.insert(id, tx);
(id, rx)
}
pub async fn unregister(&self, id: u64) {
self.map.write().await.remove(&id);
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
let m = self.map.read().await;
if let Some(tx) = m.get(&id) {
tx.send(resp).await.is_ok()
} else {
false
}
}
}

76
src/transport/middle_proxy/secret.rs Normal file
View File

@@ -0,0 +1,76 @@
use std::time::Duration;
use tracing::{debug, info, warn};
use crate::error::{ProxyError, Result};
/// Fetch Telegram proxy-secret binary.
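/// Reuses an on-disk cached copy if it is newer than 24 hours and at least 32 bytes long; otherwise
/// downloads a fresh secret and rewrites the cache on a best-effort basis.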
pub async fn fetch_proxy_secret(cache_path: Option<&str>) -> Result<Vec<u8>> {
let cache = cache_path.unwrap_or("proxy-secret");
if let Ok(metadata) = tokio::fs::metadata(cache).await {
if let Ok(modified) = metadata.modified() {
let age = std::time::SystemTime::now()
.duration_since(modified)
.unwrap_or(Duration::from_secs(u64::MAX));
if age < Duration::from_secs(86_400) {
if let Ok(data) = tokio::fs::read(cache).await {
if data.len() >= 32 {
info!(
path = cache,
len = data.len(),
age_hours = age.as_secs() / 3600,
"Loaded proxy-secret from cache"
);
return Ok(data);
}
warn!(
path = cache,
len = data.len(),
"Cached proxy-secret too short"
);
}
}
}
}
info!("Downloading proxy-secret from core.telegram.org...");
let data = download_proxy_secret().await?;
if let Err(e) = tokio::fs::write(cache, &data).await {
warn!(error = %e, "Failed to cache proxy-secret (non-fatal)");
} else {
debug!(path = cache, len = data.len(), "Cached proxy-secret");
}
Ok(data)
}
async fn download_proxy_secret() -> Result<Vec<u8>> {
let resp = reqwest::get("https://core.telegram.org/getProxySecret")
.await
.map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?;
if !resp.status().is_success() {
return Err(ProxyError::Proxy(format!(
"proxy-secret download HTTP {}",
resp.status()
)));
}
let data = resp
.bytes()
.await
.map_err(|e| ProxyError::Proxy(format!("Read proxy-secret body: {e}")))?
.to_vec();
if data.len() < 32 {
return Err(ProxyError::Proxy(format!(
"proxy-secret too short: {} bytes (need >= 32)",
data.len()
)));
}
info!(len = data.len(), "Downloaded proxy-secret OK");
Ok(data)
}

106
src/transport/middle_proxy/wire.rs Normal file
View File

@@ -0,0 +1,106 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use crate::protocol::constants::*;
#[derive(Clone, Copy)]
pub(crate) enum IpMaterial {
V4([u8; 4]),
V6([u8; 16]),
}
pub(crate) fn extract_ip_material(addr: SocketAddr) -> IpMaterial {
match addr.ip() {
IpAddr::V4(v4) => IpMaterial::V4(v4.octets()),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
IpMaterial::V4(v4.octets())
} else {
IpMaterial::V6(v6.octets())
}
}
}
}
fn ipv4_to_mapped_v6_c_compat(ip: Ipv4Addr) -> [u8; 16] {
let mut buf = [0u8; 16];
// Matches tl_store_long(0) + tl_store_int(-0x10000).
buf[8..12].copy_from_slice(&(-0x10000i32).to_le_bytes());
// Matches tl_store_int(htonl(remote_ip_host_order)).
let host_order = u32::from_ne_bytes(ip.octets());
let network_order = host_order.to_be();
buf[12..16].copy_from_slice(&network_order.to_le_bytes());
buf
}
fn append_mapped_addr_and_port(buf: &mut Vec<u8>, addr: SocketAddr) {
match addr.ip() {
IpAddr::V4(v4) => buf.extend_from_slice(&ipv4_to_mapped_v6_c_compat(v4)),
IpAddr::V6(v6) => buf.extend_from_slice(&v6.octets()),
}
buf.extend_from_slice(&(addr.port() as u32).to_le_bytes());
}
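// RPC_PROXY_REQ payload: magic (4) | proto flags (4) | conn_id (8) | client addr as 16-byte IPv6-mapped
// form + port (4) | our addr + port in the same form | optional size-prefixed extra block carrying the
// TL proxy (ad) tag when flag bits 0x4/0x8 (ext mode) are set | raw MTProto bytes from the client.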
pub(crate) fn build_proxy_req_payload(
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Vec<u8> {
let mut b = Vec::with_capacity(128 + data.len());
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
b.extend_from_slice(&proto_flags.to_le_bytes());
b.extend_from_slice(&conn_id.to_le_bytes());
append_mapped_addr_and_port(&mut b, client_addr);
append_mapped_addr_and_port(&mut b, our_addr);
if proto_flags & 12 != 0 {
let extra_start = b.len();
b.extend_from_slice(&0u32.to_le_bytes());
if let Some(tag) = proxy_tag {
b.extend_from_slice(&TL_PROXY_TAG_U32.to_le_bytes());
if tag.len() < 254 {
b.push(tag.len() as u8);
b.extend_from_slice(tag);
let pad = (4 - ((1 + tag.len()) % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
} else {
b.push(0xfe);
let len_bytes = (tag.len() as u32).to_le_bytes();
b.extend_from_slice(&len_bytes[..3]);
b.extend_from_slice(tag);
let pad = (4 - (tag.len() % 4)) % 4;
b.extend(std::iter::repeat_n(0u8, pad));
}
}
let extra_bytes = (b.len() - extra_start - 4) as u32;
b[extra_start..extra_start + 4].copy_from_slice(&extra_bytes.to_le_bytes());
}
b.extend_from_slice(data);
b
}
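// Base flags always include RPC_FLAG_MAGIC and RPC_FLAG_EXTMODE2; the ad-tag bit is added only when a
// proxy tag is configured, and the padded ("secure") transport is signalled as Intermediate + RPC_FLAG_PAD.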
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 {
use crate::protocol::constants::ProtoTag;
let mut flags = RPC_FLAG_MAGIC | RPC_FLAG_EXTMODE2;
if has_proxy_tag {
flags |= RPC_FLAG_HAS_AD_TAG;
}
match tag {
ProtoTag::Abridged => flags | RPC_FLAG_ABRIDGED,
ProtoTag::Intermediate => flags | RPC_FLAG_INTERMEDIATE,
ProtoTag::Secure => flags | RPC_FLAG_PAD | RPC_FLAG_INTERMEDIATE,
}
}

src/transport/mod.rs
View File

@@ -10,5 +10,5 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*; pub use socket::*;
pub use socks::*; pub use socks::*;
pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult}; pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager};
pub mod middle_proxy; pub mod middle_proxy;