diff --git a/README.md b/README.md index a922f40..f93e5da 100644 --- a/README.md +++ b/README.md @@ -118,44 +118,100 @@ then Ctrl+X -> Y -> Enter to save ## Configuration ### Minimal Configuration for First Start ```toml -port = 443 # Listening port -show_link = ["tele", "hello"] # Specify users, for whom will be displayed the links +# === General Settings === +[general] +prefer_ipv6 = false +fast_mode = true +use_middle_proxy = false +# ad_tag = "..." -tls_domain = "petrovich.ru" # Domain for ee-secret and masking -mask = true # Enable masking of bad traffic -mask_host = "petrovich.ru" # Optional override for mask destination -mask_port = 443 # Port for masking +[general.modes] +classic = false +secure = false +tls = true -prefer_ipv6 = false # Try IPv6 DCs first if true -fast_mode = true # Use "fast" obfuscation variant +# === Server Binding === +[server] +port = 443 +listen_addr_ipv4 = "0.0.0.0" +listen_addr_ipv6 = "::" +# metrics_port = 9090 +# metrics_whitelist = ["127.0.0.1", "::1"] -client_keepalive = 600 # Seconds -client_ack_timeout = 300 # Seconds +# Listen on multiple interfaces/IPs (overrides listen_addr_*) +[[server.listeners]] +ip = "0.0.0.0" +# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links -[modes] -classic = true # Plain obfuscated mode -secure = true # dd-prefix mode -tls = true # Fake TLS (ee-prefix) +[[server.listeners]] +ip = "::" -[users] -hello = "00000000000000000000000000000000" # Replace the secret with one generated before -tele = "00000000000000000000000000000000" # Replace the secret with one generated before +# === Timeouts (in seconds) === +[timeouts] +client_handshake = 15 +tg_connect = 10 +client_keepalive = 60 +client_ack = 300 + +# === Anti-Censorship & Masking === +[censorship] +tls_domain = "petrovich.ru" +mask = true +mask_port = 443 +# mask_host = "petrovich.ru" # Defaults to tls_domain if not set +fake_cert_len = 2048 + +# === Access Control & Users === +# username "hello" is used for example +[access] +replay_check_len = 65536 +ignore_time_skew = false + +[access.users] +# format: "username" = "32_hex_chars_secret" +hello = "00000000000000000000000000000000" + +# [access.user_max_tcp_conns] +# hello = 50 + +# [access.user_data_quota] +# hello = 1073741824 # 1 GB + +# === Upstreams & Routing === +# By default, direct connection is used, but you can add SOCKS proxy + +# Direct - Default +[[upstreams]] +type = "direct" +enabled = true +weight = 10 + +# SOCKS5 +# [[upstreams]] +# type = "socks5" +# address = "127.0.0.1:9050" +# enabled = false +# weight = 1 + +# === UI === +# Users to show in the startup log (tg:// links) +show_link = ["hello"] ``` ### Advanced #### Adtag -To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to the end of config.toml and specify it +To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[General]` ```toml ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot ``` #### Listening and Announce IPs -To specify listening address and/or address in links, add to the end of config.toml: +To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml: ```toml -[[listeners]] +[[server.listeners]] ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening announce_ip = "1.2.3.4" # IP in links; comment with # if not used ``` #### Upstream Manager -To specify upstream, add to the end of config.toml: +To specify upstream, add to section `[[upstreams]]` of config.toml: ##### Bind on IP ```toml [[upstreams]] diff --git a/config.toml b/config.toml index b8a62af..73b9b06 100644 --- a/config.toml +++ b/config.toml @@ -1,13 +1,78 @@ -port = 443 +# === General Settings === +[general] +prefer_ipv6 = false +fast_mode = true +use_middle_proxy = false +# ad_tag = "..." -[users] -user1 = "00000000000000000000000000000000" - -[modes] -classic = true -secure = true +[general.modes] +classic = false +secure = false tls = true -tls_domain = "www.github.com" -fast_mode = true -prefer_ipv6 = false \ No newline at end of file +# === Server Binding === +[server] +port = 443 +listen_addr_ipv4 = "0.0.0.0" +listen_addr_ipv6 = "::" +# metrics_port = 9090 +# metrics_whitelist = ["127.0.0.1", "::1"] + +# Listen on multiple interfaces/IPs (overrides listen_addr_*) +[[server.listeners]] +ip = "0.0.0.0" +# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links + +[[server.listeners]] +ip = "::" + +# === Timeouts (in seconds) === +[timeouts] +client_handshake = 15 +tg_connect = 10 +client_keepalive = 60 +client_ack = 300 + +# === Anti-Censorship & Masking === +[censorship] +tls_domain = "petrovich.ru" +mask = true +mask_port = 443 +# mask_host = "petrovich.ru" # Defaults to tls_domain if not set +fake_cert_len = 2048 + +# === Access Control & Users === +# username "hello" is used for example +[access] +replay_check_len = 65536 +ignore_time_skew = false + +[access.users] +# format: "username" = "32_hex_chars_secret" +hello = "00000000000000000000000000000000" + +# [access.user_max_tcp_conns] +# hello = 50 + +# [access.user_data_quota] +# hello = 1073741824 # 1 GB + +# === Upstreams & Routing === +# By default, direct connection is used, but you can add SOCKS proxy + +# Direct - Default +[[upstreams]] +type = "direct" +enabled = true +weight = 10 + +# SOCKS5 +# [[upstreams]] +# type = "socks5" +# address = "127.0.0.1:9050" +# enabled = false +# weight = 1 + +# === UI === +# Users to show in the startup log (tg:// links) +show_link = ["hello"] \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index bbe3f61..425aeef 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -7,6 +7,29 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::error::{ProxyError, Result}; +// ============= Helper Defaults ============= + +fn default_true() -> bool { true } +fn default_port() -> u16 { 443 } +fn default_tls_domain() -> String { "www.google.com".to_string() } +fn default_mask_port() -> u16 { 443 } +fn default_replay_check_len() -> usize { 65536 } +fn default_handshake_timeout() -> u64 { 15 } +fn default_connect_timeout() -> u64 { 10 } +fn default_keepalive() -> u64 { 60 } +fn default_ack_timeout() -> u64 { 300 } +fn default_listen_addr() -> String { "0.0.0.0".to_string() } +fn default_fake_cert_len() -> usize { 2048 } +fn default_weight() -> u16 { 1 } +fn default_metrics_whitelist() -> Vec { + vec![ + "127.0.0.1".parse().unwrap(), + "::1".parse().unwrap(), + ] +} + +// ============= Sub-Configs ============= + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProxyModes { #[serde(default)] @@ -17,26 +40,185 @@ pub struct ProxyModes { pub tls: bool, } -fn default_true() -> bool { true } -fn default_weight() -> u16 { 1 } - impl Default for ProxyModes { fn default() -> Self { Self { classic: true, secure: true, tls: true } } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralConfig { + #[serde(default)] + pub modes: ProxyModes, + + #[serde(default)] + pub prefer_ipv6: bool, + + #[serde(default = "default_true")] + pub fast_mode: bool, + + #[serde(default)] + pub use_middle_proxy: bool, + + #[serde(default)] + pub ad_tag: Option, +} + +impl Default for GeneralConfig { + fn default() -> Self { + Self { + modes: ProxyModes::default(), + prefer_ipv6: false, + fast_mode: true, + use_middle_proxy: false, + ad_tag: None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default = "default_listen_addr")] + pub listen_addr_ipv4: String, + + #[serde(default)] + pub listen_addr_ipv6: Option, + + #[serde(default)] + pub listen_unix_sock: Option, + + #[serde(default)] + pub metrics_port: Option, + + #[serde(default = "default_metrics_whitelist")] + pub metrics_whitelist: Vec, + + #[serde(default)] + pub listeners: Vec, +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + port: default_port(), + listen_addr_ipv4: default_listen_addr(), + listen_addr_ipv6: Some("::".to_string()), + listen_unix_sock: None, + metrics_port: None, + metrics_whitelist: default_metrics_whitelist(), + listeners: Vec::new(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TimeoutsConfig { + #[serde(default = "default_handshake_timeout")] + pub client_handshake: u64, + + #[serde(default = "default_connect_timeout")] + pub tg_connect: u64, + + #[serde(default = "default_keepalive")] + pub client_keepalive: u64, + + #[serde(default = "default_ack_timeout")] + pub client_ack: u64, +} + +impl Default for TimeoutsConfig { + fn default() -> Self { + Self { + client_handshake: default_handshake_timeout(), + tg_connect: default_connect_timeout(), + client_keepalive: default_keepalive(), + client_ack: default_ack_timeout(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AntiCensorshipConfig { + #[serde(default = "default_tls_domain")] + pub tls_domain: String, + + #[serde(default = "default_true")] + pub mask: bool, + + #[serde(default)] + pub mask_host: Option, + + #[serde(default = "default_mask_port")] + pub mask_port: u16, + + #[serde(default = "default_fake_cert_len")] + pub fake_cert_len: usize, +} + +impl Default for AntiCensorshipConfig { + fn default() -> Self { + Self { + tls_domain: default_tls_domain(), + mask: true, + mask_host: None, + mask_port: default_mask_port(), + fake_cert_len: default_fake_cert_len(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessConfig { + #[serde(default)] + pub users: HashMap, + + #[serde(default)] + pub user_max_tcp_conns: HashMap, + + #[serde(default)] + pub user_expirations: HashMap>, + + #[serde(default)] + pub user_data_quota: HashMap, + + #[serde(default = "default_replay_check_len")] + pub replay_check_len: usize, + + #[serde(default)] + pub ignore_time_skew: bool, +} + +impl Default for AccessConfig { + fn default() -> Self { + let mut users = HashMap::new(); + users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); + Self { + users, + user_max_tcp_conns: HashMap::new(), + user_expirations: HashMap::new(), + user_data_quota: HashMap::new(), + replay_check_len: default_replay_check_len(), + ignore_time_skew: false, + } + } +} + +// ============= Aux Structures ============= + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type", rename_all = "lowercase")] pub enum UpstreamType { Direct { #[serde(default)] - interface: Option, // Bind to specific IP/Interface + interface: Option, }, Socks4 { - address: String, // IP:Port of SOCKS server + address: String, #[serde(default)] - interface: Option, // Bind to specific IP/Interface for connection to SOCKS + interface: Option, #[serde(default)] user_id: Option, }, @@ -65,160 +247,35 @@ pub struct UpstreamConfig { pub struct ListenerConfig { pub ip: IpAddr, #[serde(default)] - pub announce_ip: Option, // IP to show in tg:// links + pub announce_ip: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ProxyConfig { - #[serde(default = "default_port")] - pub port: u16, - - #[serde(default)] - pub users: HashMap, - - #[serde(default)] - pub ad_tag: Option, - - #[serde(default)] - pub modes: ProxyModes, - - #[serde(default = "default_tls_domain")] - pub tls_domain: String, - - #[serde(default = "default_true")] - pub mask: bool, - - #[serde(default)] - pub mask_host: Option, - - #[serde(default = "default_mask_port")] - pub mask_port: u16, - - #[serde(default)] - pub prefer_ipv6: bool, - - #[serde(default = "default_true")] - pub fast_mode: bool, - - #[serde(default)] - pub use_middle_proxy: bool, - - #[serde(default)] - pub user_max_tcp_conns: HashMap, - - #[serde(default)] - pub user_expirations: HashMap>, - - #[serde(default)] - pub user_data_quota: HashMap, - - #[serde(default = "default_replay_check_len")] - pub replay_check_len: usize, - - #[serde(default)] - pub ignore_time_skew: bool, - - #[serde(default = "default_handshake_timeout")] - pub client_handshake_timeout: u64, - - #[serde(default = "default_connect_timeout")] - pub tg_connect_timeout: u64, - - #[serde(default = "default_keepalive")] - pub client_keepalive: u64, - - #[serde(default = "default_ack_timeout")] - pub client_ack_timeout: u64, - - #[serde(default = "default_listen_addr")] - pub listen_addr_ipv4: String, - - #[serde(default)] - pub listen_addr_ipv6: Option, - - #[serde(default)] - pub listen_unix_sock: Option, - - #[serde(default)] - pub metrics_port: Option, - - #[serde(default = "default_metrics_whitelist")] - pub metrics_whitelist: Vec, - - #[serde(default = "default_fake_cert_len")] - pub fake_cert_len: usize, +// ============= Main Config ============= + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProxyConfig { + #[serde(default)] + pub general: GeneralConfig, + + #[serde(default)] + pub server: ServerConfig, + + #[serde(default)] + pub timeouts: TimeoutsConfig, + + #[serde(default)] + pub censorship: AntiCensorshipConfig, + + #[serde(default)] + pub access: AccessConfig, - // New fields #[serde(default)] pub upstreams: Vec, - #[serde(default)] - pub listeners: Vec, - #[serde(default)] pub show_link: Vec, } -fn default_port() -> u16 { 443 } -fn default_tls_domain() -> String { "www.google.com".to_string() } -fn default_mask_port() -> u16 { 443 } -fn default_replay_check_len() -> usize { 65536 } -// CHANGED: Increased handshake timeout for bad mobile networks -fn default_handshake_timeout() -> u64 { 15 } -fn default_connect_timeout() -> u64 { 10 } -// CHANGED: Reduced keepalive from 600s to 60s. -// Mobile NATs often drop idle connections after 60-120s. -fn default_keepalive() -> u64 { 60 } -fn default_ack_timeout() -> u64 { 300 } -fn default_listen_addr() -> String { "0.0.0.0".to_string() } -fn default_fake_cert_len() -> usize { 2048 } - -fn default_metrics_whitelist() -> Vec { - vec![ - "127.0.0.1".parse().unwrap(), - "::1".parse().unwrap(), - ] -} - -impl Default for ProxyConfig { - fn default() -> Self { - let mut users = HashMap::new(); - users.insert("default".to_string(), "00000000000000000000000000000000".to_string()); - - Self { - port: default_port(), - users, - ad_tag: None, - modes: ProxyModes::default(), - tls_domain: default_tls_domain(), - mask: true, - mask_host: None, - mask_port: default_mask_port(), - prefer_ipv6: false, - fast_mode: true, - use_middle_proxy: false, - user_max_tcp_conns: HashMap::new(), - user_expirations: HashMap::new(), - user_data_quota: HashMap::new(), - replay_check_len: default_replay_check_len(), - ignore_time_skew: false, - client_handshake_timeout: default_handshake_timeout(), - tg_connect_timeout: default_connect_timeout(), - client_keepalive: default_keepalive(), - client_ack_timeout: default_ack_timeout(), - listen_addr_ipv4: default_listen_addr(), - listen_addr_ipv6: Some("::".to_string()), - listen_unix_sock: None, - metrics_port: None, - metrics_whitelist: default_metrics_whitelist(), - fake_cert_len: default_fake_cert_len(), - upstreams: Vec::new(), - listeners: Vec::new(), - show_link: Vec::new(), - } - } -} - impl ProxyConfig { pub fn load>(path: P) -> Result { let content = std::fs::read_to_string(path) @@ -228,7 +285,7 @@ impl ProxyConfig { .map_err(|e| ProxyError::Config(e.to_string()))?; // Validate secrets - for (user, secret) in &config.users { + for (user, secret) in &config.access.users { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { return Err(ProxyError::InvalidSecret { user: user.clone(), @@ -237,26 +294,37 @@ impl ProxyConfig { } } - // Default mask_host - if config.mask_host.is_none() { - config.mask_host = Some(config.tls_domain.clone()); + // Validate tls_domain + if config.censorship.tls_domain.is_empty() { + return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); + } + + // Warn if using default tls_domain + if config.censorship.tls_domain == "www.google.com" { + tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml"); + } + + // Default mask_host to tls_domain if not set + if config.censorship.mask_host.is_none() { + tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain); + config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); } // Random fake_cert_len use rand::Rng; - config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); + config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096); // Migration: Populate listeners if empty - if config.listeners.is_empty() { - if let Ok(ipv4) = config.listen_addr_ipv4.parse::() { - config.listeners.push(ListenerConfig { + if config.server.listeners.is_empty() { + if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::() { + config.server.listeners.push(ListenerConfig { ip: ipv4, announce_ip: None, }); } - if let Some(ipv6_str) = &config.listen_addr_ipv6 { + if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { if let Ok(ipv6) = ipv6_str.parse::() { - config.listeners.push(ListenerConfig { + config.server.listeners.push(ListenerConfig { ip: ipv6, announce_ip: None, }); @@ -277,14 +345,21 @@ impl ProxyConfig { } pub fn validate(&self) -> Result<()> { - if self.users.is_empty() { + if self.access.users.is_empty() { return Err(ProxyError::Config("No users configured".to_string())); } - if !self.modes.classic && !self.modes.secure && !self.modes.tls { + if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls { return Err(ProxyError::Config("No modes enabled".to_string())); } + // Validate tls_domain format (basic check) + if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { + return Err(ProxyError::Config( + format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain) + )); + } + Ok(()) } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs index d20b8d8..6d9bae3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -297,16 +297,16 @@ pub type StreamResult = std::result::Result; /// Result with optional bad client handling #[derive(Debug)] -pub enum HandshakeResult { +pub enum HandshakeResult { /// Handshake succeeded Success(T), - /// Client failed validation, needs masking - BadClient, + /// Client failed validation, needs masking. Returns ownership of streams. + BadClient { reader: R, writer: W }, /// Error occurred Error(ProxyError), } -impl HandshakeResult { +impl HandshakeResult { /// Check if successful pub fn is_success(&self) -> bool { matches!(self, HandshakeResult::Success(_)) @@ -314,49 +314,32 @@ impl HandshakeResult { /// Check if bad client pub fn is_bad_client(&self) -> bool { - matches!(self, HandshakeResult::BadClient) - } - - /// Convert to Result, treating BadClient as error - pub fn into_result(self) -> Result { - match self { - HandshakeResult::Success(v) => Ok(v), - HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())), - HandshakeResult::Error(e) => Err(e), - } + matches!(self, HandshakeResult::BadClient { .. }) } /// Map the success value - pub fn map U>(self, f: F) -> HandshakeResult { + pub fn map U>(self, f: F) -> HandshakeResult { match self { HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), - HandshakeResult::BadClient => HandshakeResult::BadClient, + HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer }, HandshakeResult::Error(e) => HandshakeResult::Error(e), } } - - /// Convert success to Option - pub fn ok(self) -> Option { - match self { - HandshakeResult::Success(v) => Some(v), - _ => None, - } - } } -impl From for HandshakeResult { +impl From for HandshakeResult { fn from(err: ProxyError) -> Self { HandshakeResult::Error(err) } } -impl From for HandshakeResult { +impl From for HandshakeResult { fn from(err: std::io::Error) -> Self { HandshakeResult::Error(ProxyError::Io(err)) } } -impl From for HandshakeResult { +impl From for HandshakeResult { fn from(err: StreamError) -> Self { HandshakeResult::Error(ProxyError::Stream(err)) } diff --git a/src/main.rs b/src/main.rs index fde2a59..a672ed0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,6 +23,7 @@ use crate::proxy::ClientHandler; use crate::stats::{Stats, ReplayChecker}; use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::util::ip::detect_ip; +use crate::stream::BufferPool; #[tokio::main] async fn main() -> Result<(), Box> { @@ -52,15 +53,33 @@ async fn main() -> Result<(), Box> { config.validate()?; + // Log loaded configuration for debugging + info!("=== Configuration Loaded ==="); + info!("TLS Domain: {}", config.censorship.tls_domain); + info!("Mask enabled: {}", config.censorship.mask); + info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain)); + info!("Mask port: {}", config.censorship.mask_port); + info!("Modes: classic={}, secure={}, tls={}", + config.general.modes.classic, + config.general.modes.secure, + config.general.modes.tls + ); + info!("============================"); + let config = Arc::new(config); let stats = Arc::new(Stats::new()); - // CHANGED: Initialize global ReplayChecker here instead of per-connection - let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); + // Initialize global ReplayChecker + // Using sharded implementation for better concurrency + let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len)); // Initialize Upstream Manager let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); + // Initialize Buffer Pool + // 16KB buffers, max 4096 buffers (~64MB total cached) + let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096)); + // Start Health Checks let um_clone = upstream_manager.clone(); tokio::spawn(async move { @@ -73,8 +92,8 @@ async fn main() -> Result<(), Box> { // Start Listeners let mut listeners = Vec::new(); - for listener_conf in &config.listeners { - let addr = SocketAddr::new(listener_conf.ip, config.port); + for listener_conf in &config.server.listeners { + let addr = SocketAddr::new(listener_conf.ip, config.server.port); let options = ListenOptions { ipv6_only: listener_conf.ip.is_ipv6(), ..Default::default() @@ -86,13 +105,9 @@ async fn main() -> Result<(), Box> { info!("Listening on {}", addr); // Determine public IP for tg:// links - // 1. Use explicit announce_ip if set - // 2. If listening on 0.0.0.0 or ::, use detected public IP - // 3. Otherwise use the bind IP let public_ip = if let Some(ip) = listener_conf.announce_ip { ip } else if listener_conf.ip.is_unspecified() { - // Try to use detected IP of the same family if listener_conf.ip.is_ipv4() { detected_ip.ipv4.unwrap_or(listener_conf.ip) } else { @@ -106,26 +121,23 @@ async fn main() -> Result<(), Box> { if !config.show_link.is_empty() { info!("--- Proxy Links for {} ---", public_ip); for user_name in &config.show_link { - if let Some(secret) = config.users.get(user_name) { + if let Some(secret) = config.access.users.get(user_name) { info!("User: {}", user_name); - // Classic - if config.modes.classic { + if config.general.modes.classic { info!(" Classic: tg://proxy?server={}&port={}&secret={}", - public_ip, config.port, secret); + public_ip, config.server.port, secret); } - // DD (Secure) - if config.modes.secure { + if config.general.modes.secure { info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", - public_ip, config.port, secret); + public_ip, config.server.port, secret); } - // EE-TLS (FakeTLS) - if config.modes.tls { - let domain_hex = hex::encode(&config.tls_domain); + if config.general.modes.tls { + let domain_hex = hex::encode(&config.censorship.tls_domain); info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", - public_ip, config.port, secret, domain_hex); + public_ip, config.server.port, secret, domain_hex); } } else { warn!("User '{}' specified in show_link not found in users list", user_name); @@ -153,6 +165,7 @@ async fn main() -> Result<(), Box> { let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); tokio::spawn(async move { loop { @@ -162,6 +175,7 @@ async fn main() -> Result<(), Box> { let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); tokio::spawn(async move { if let Err(e) = ClientHandler::new( @@ -170,10 +184,10 @@ async fn main() -> Result<(), Box> { config, stats, upstream_manager, - replay_checker // Pass global checker + replay_checker, + buffer_pool ).run().await { // Log only relevant errors - // debug!("Connection error: {}", e); } }); } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 11e3a81..67a3728 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -14,7 +14,7 @@ use crate::protocol::constants::*; use crate::protocol::tls; use crate::stats::{Stats, ReplayChecker}; use crate::transport::{configure_client_socket, UpstreamManager}; -use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; +use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool}; use crate::crypto::AesCtr; use super::handshake::{ @@ -35,6 +35,7 @@ pub struct RunningClientHandler { stats: Arc, replay_checker: Arc, upstream_manager: Arc, + buffer_pool: Arc, } impl ClientHandler { @@ -45,11 +46,9 @@ impl ClientHandler { config: Arc, stats: Arc, upstream_manager: Arc, - replay_checker: Arc, // CHANGED: Accept global checker + replay_checker: Arc, + buffer_pool: Arc, ) -> RunningClientHandler { - // CHANGED: Removed local creation of ReplayChecker. - // It is now passed from main.rs to ensure global replay protection. - RunningClientHandler { stream, peer, @@ -57,6 +56,7 @@ impl ClientHandler { stats, replay_checker, upstream_manager, + buffer_pool, } } } @@ -72,14 +72,14 @@ impl RunningClientHandler { // Configure socket if let Err(e) = configure_client_socket( &self.stream, - self.config.client_keepalive, - self.config.client_ack_timeout, + self.config.timeouts.client_keepalive, + self.config.timeouts.client_ack, ) { debug!(peer = %peer, error = %e, "Failed to configure client socket"); } // Perform handshake with timeout - let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); + let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); // Clone stats for error handling block let stats = self.stats.clone(); @@ -139,7 +139,9 @@ impl RunningClientHandler { if tls_len < 512 { debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); self.stats.increment_connects_bad(); - handle_bad_client(self.stream, &first_bytes, &self.config).await; + // FIX: Split stream into reader/writer for handle_bad_client + let (reader, writer) = self.stream.into_split(); + handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } @@ -152,6 +154,7 @@ impl RunningClientHandler { let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); // Split stream for reading/writing let (read_half, write_half) = self.stream.into_split(); @@ -166,8 +169,9 @@ impl RunningClientHandler { &replay_checker, ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; return Ok(()); } HandshakeResult::Error(e) => return Err(e), @@ -190,27 +194,23 @@ impl RunningClientHandler { true, ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); + // Valid TLS but invalid MTProto - drop + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping"); return Ok(()); } HandshakeResult::Error(e) => return Err(e), }; - // Handle authenticated client - // We can't use self.handle_authenticated_inner because self is partially moved - // So we call it as an associated function or method on a new struct, - // or just inline the logic / use a static method. - // Since handle_authenticated_inner needs self.upstream_manager and self.stats, - // we should pass them explicitly. - Self::handle_authenticated_static( crypto_reader, crypto_writer, success, self.upstream_manager, self.stats, - self.config + self.config, + buffer_pool ).await } @@ -222,10 +222,12 @@ impl RunningClientHandler { let peer = self.peer; // Check if non-TLS modes are enabled - if !self.config.modes.classic && !self.config.modes.secure { + if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); - handle_bad_client(self.stream, &first_bytes, &self.config).await; + // FIX: Split stream into reader/writer for handle_bad_client + let (reader, writer) = self.stream.into_split(); + handle_bad_client(reader, writer, &first_bytes, &self.config).await; return Ok(()); } @@ -238,6 +240,7 @@ impl RunningClientHandler { let config = self.config.clone(); let replay_checker = self.replay_checker.clone(); let stats = self.stats.clone(); + let buffer_pool = self.buffer_pool.clone(); // Split stream let (read_half, write_half) = self.stream.into_split(); @@ -253,8 +256,9 @@ impl RunningClientHandler { false, ).await { HandshakeResult::Success(result) => result, - HandshakeResult::BadClient => { + HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; return Ok(()); } HandshakeResult::Error(e) => return Err(e), @@ -266,11 +270,12 @@ impl RunningClientHandler { success, self.upstream_manager, self.stats, - self.config + self.config, + buffer_pool ).await } - /// Static version of handle_authenticated_inner to avoid ownership issues + /// Static version of handle_authenticated_inner async fn handle_authenticated_static( client_reader: CryptoReader, client_writer: CryptoWriter, @@ -278,6 +283,7 @@ impl RunningClientHandler { upstream_manager: Arc, stats: Arc, config: Arc, + buffer_pool: Arc, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -300,7 +306,7 @@ impl RunningClientHandler { dc = success.dc_idx, dc_addr = %dc_addr, proto = ?success.proto_tag, - fast_mode = config.fast_mode, + fast_mode = config.general.fast_mode, "Connecting to Telegram" ); @@ -322,7 +328,7 @@ impl RunningClientHandler { stats.increment_user_connects(user); stats.increment_user_curr_connects(user); - // Relay traffic + // Relay traffic using buffer pool let relay_result = relay_bidirectional( client_reader, client_writer, @@ -330,6 +336,7 @@ impl RunningClientHandler { tg_writer, user, Arc::clone(&stats), + buffer_pool, ).await; // Update stats @@ -346,14 +353,14 @@ impl RunningClientHandler { /// Check user limits (static version) fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { // Check expiration - if let Some(expiration) = config.user_expirations.get(user) { + if let Some(expiration) = config.access.user_expirations.get(user) { if chrono::Utc::now() > *expiration { return Err(ProxyError::UserExpired { user: user.to_string() }); } } // Check connection limit - if let Some(limit) = config.user_max_tcp_conns.get(user) { + if let Some(limit) = config.access.user_max_tcp_conns.get(user) { let current = stats.get_user_curr_connects(user); if current >= *limit as u64 { return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); @@ -361,7 +368,7 @@ impl RunningClientHandler { } // Check data quota - if let Some(quota) = config.user_data_quota.get(user) { + if let Some(quota) = config.access.user_data_quota.get(user) { let used = stats.get_user_total_octets(user); if used >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); @@ -375,7 +382,7 @@ impl RunningClientHandler { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result { let idx = (dc_idx.abs() - 1) as usize; - let datacenters = if config.prefer_ipv6 { + let datacenters = if config.general.prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 @@ -399,7 +406,7 @@ impl RunningClientHandler { success.proto_tag, &success.dec_key, // Client's dec key success.dec_iv, - config.fast_mode, + config.general.fast_mode, ); // Encrypt nonce diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 19e8baa..241644f 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -42,7 +42,7 @@ pub async fn handle_tls_handshake( peer: SocketAddr, config: &ProxyConfig, replay_checker: &ReplayChecker, -) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String)> +) -> HandshakeResult<(FakeTlsReader, FakeTlsWriter, String), R, W> where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, @@ -52,7 +52,7 @@ where // Check minimum length if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { debug!(peer = %peer, "TLS handshake too short"); - return HandshakeResult::BadClient; + return HandshakeResult::BadClient { reader, writer }; } // Extract digest for replay check @@ -62,11 +62,11 @@ where // Check for replay if replay_checker.check_tls_digest(digest_half) { warn!(peer = %peer, "TLS replay attack detected"); - return HandshakeResult::BadClient; + return HandshakeResult::BadClient { reader, writer }; } // Build secrets list - let secrets: Vec<(String, Vec)> = config.users.iter() + let secrets: Vec<(String, Vec)> = config.access.users.iter() .filter_map(|(name, hex)| { hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) }) @@ -78,19 +78,19 @@ where let validation = match tls::validate_tls_handshake( handshake, &secrets, - config.ignore_time_skew, + config.access.ignore_time_skew, ) { Some(v) => v, None => { debug!(peer = %peer, "TLS handshake validation failed - no matching user"); - return HandshakeResult::BadClient; + return HandshakeResult::BadClient { reader, writer }; } }; // Get secret for response let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, - None => return HandshakeResult::BadClient, + None => return HandshakeResult::BadClient { reader, writer }, }; // Build and send response @@ -98,7 +98,7 @@ where secret, &validation.digest, &validation.session_id, - config.fake_cert_len, + config.censorship.fake_cert_len, ); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); @@ -136,7 +136,7 @@ pub async fn handle_mtproto_handshake( config: &ProxyConfig, replay_checker: &ReplayChecker, is_tls: bool, -) -> HandshakeResult<(CryptoReader, CryptoWriter, HandshakeSuccess)> +) -> HandshakeResult<(CryptoReader, CryptoWriter, HandshakeSuccess), R, W> where R: AsyncRead + Unpin + Send, W: AsyncWrite + Unpin + Send, @@ -155,14 +155,14 @@ where // Check for replay if replay_checker.check_handshake(dec_prekey_iv) { warn!(peer = %peer, "MTProto replay attack detected"); - return HandshakeResult::BadClient; + return HandshakeResult::BadClient { reader, writer }; } // Reversed for encryption direction let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); // Try each user's secret - for (user, secret_hex) in &config.users { + for (user, secret_hex) in &config.access.users { let secret = match hex::decode(secret_hex) { Ok(s) => s, Err(_) => continue, @@ -208,9 +208,9 @@ where // Check if mode is enabled let mode_ok = match proto_tag { ProtoTag::Secure => { - if is_tls { config.modes.tls } else { config.modes.secure } + if is_tls { config.general.modes.tls } else { config.general.modes.secure } } - ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, + ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, }; if !mode_ok { @@ -270,7 +270,7 @@ where } debug!(peer = %peer, "MTProto handshake: no matching user found"); - HandshakeResult::BadClient + HandshakeResult::BadClient { reader, writer } } /// Generate nonce for Telegram connection diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 05e018e..27bb867 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -1,35 +1,73 @@ //! Masking - forward unrecognized traffic to mask host use std::time::Duration; +use std::str; use tokio::net::TcpStream; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::time::timeout; use tracing::debug; use crate::config::ProxyConfig; -use crate::transport::set_linger_zero; const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_BUFFER_SIZE: usize = 8192; +/// Detect client type based on initial data +fn detect_client_type(data: &[u8]) -> &'static str { + // Check for HTTP request + if data.len() > 4 { + if data.starts_with(b"GET ") || data.starts_with(b"POST") || + data.starts_with(b"HEAD") || data.starts_with(b"PUT ") || + data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") { + return "HTTP"; + } + } + + // Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version) + if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 { + return "TLS-scanner"; + } + + // Check for SSH + if data.starts_with(b"SSH-") { + return "SSH"; + } + + // Port scanner (very short data) + if data.len() < 10 { + return "port-scanner"; + } + + "unknown" +} + /// Handle a bad client by forwarding to mask host -pub async fn handle_bad_client( - client: TcpStream, +pub async fn handle_bad_client( + mut reader: R, + mut writer: W, initial_data: &[u8], config: &ProxyConfig, -) { - if !config.mask { +) +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data(client).await; + consume_client_data(reader).await; return; } - let mask_host = config.mask_host.as_deref() - .unwrap_or(&config.tls_domain); - let mask_port = config.mask_port; + let client_type = detect_client_type(initial_data); + + let mask_host = config.censorship.mask_host.as_deref() + .unwrap_or(&config.censorship.tls_domain); + let mask_port = config.censorship.mask_port; debug!( + client_type = client_type, host = %mask_host, port = mask_port, + data_len = initial_data.len(), "Forwarding bad client to mask host" ); @@ -40,33 +78,32 @@ pub async fn handle_bad_client( TcpStream::connect(&mask_addr) ).await; - let mut mask_stream = match connect_result { + let mask_stream = match connect_result { Ok(Ok(s)) => s, Ok(Err(e)) => { debug!(error = %e, "Failed to connect to mask host"); - consume_client_data(client).await; + consume_client_data(reader).await; return; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data(client).await; + consume_client_data(reader).await; return; } }; + let (mut mask_read, mut mask_write) = mask_stream.into_split(); + // Send initial data to mask host - if mask_stream.write_all(initial_data).await.is_err() { + if mask_write.write_all(initial_data).await.is_err() { return; } // Relay traffic - let (mut client_read, mut client_write) = client.into_split(); - let (mut mask_read, mut mask_write) = mask_stream.into_split(); - let c2m = tokio::spawn(async move { let mut buf = vec![0u8; MASK_BUFFER_SIZE]; loop { - match client_read.read(&mut buf).await { + match reader.read(&mut buf).await { Ok(0) | Err(_) => { let _ = mask_write.shutdown().await; break; @@ -85,11 +122,11 @@ pub async fn handle_bad_client( loop { match mask_read.read(&mut buf).await { Ok(0) | Err(_) => { - let _ = client_write.shutdown().await; + let _ = writer.shutdown().await; break; } Ok(n) => { - if client_write.write_all(&buf[..n]).await.is_err() { + if writer.write_all(&buf[..n]).await.is_err() { break; } } @@ -105,9 +142,9 @@ pub async fn handle_bad_client( } /// Just consume all data from client without responding -async fn consume_client_data(mut client: TcpStream) { +async fn consume_client_data(mut reader: R) { let mut buf = vec![0u8; MASK_BUFFER_SIZE]; - while let Ok(n) = client.read(&mut buf).await { + while let Ok(n) = reader.read(&mut buf).await { if n == 0 { break; } diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index f90b247..22e5b7c 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -7,14 +7,10 @@ use tokio::time::Instant; use tracing::{debug, trace, warn, info}; use crate::error::Result; use crate::stats::Stats; +use crate::stream::BufferPool; use std::sync::atomic::{AtomicU64, Ordering}; -// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat. -// This is critical for iOS clients to maintain proper TCP flow control during uploads. -const BUFFER_SIZE: usize = 16384; - // Activity timeout for iOS compatibility (30 minutes) -// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout const ACTIVITY_TIMEOUT_SECS: u64 = 1800; /// Relay data bidirectionally between client and server @@ -25,6 +21,7 @@ pub async fn relay_bidirectional( mut server_writer: SW, user: &str, stats: Arc, + buffer_pool: Arc, ) -> Result<()> where CR: AsyncRead + Unpin + Send + 'static, @@ -35,7 +32,6 @@ where let user_c2s = user.to_string(); let user_s2c = user.to_string(); - // Используем Arc::clone вместо stats.clone() let stats_c2s = Arc::clone(&stats); let stats_s2c = Arc::clone(&stats); @@ -44,26 +40,29 @@ where let c2s_bytes_clone = Arc::clone(&c2s_bytes); let s2c_bytes_clone = Arc::clone(&s2c_bytes); - // Activity timeout for iOS compatibility let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS); - // Client -> Server task with activity timeout + let pool_c2s = buffer_pool.clone(); + let pool_s2c = buffer_pool.clone(); + + // Client -> Server task let c2s = tokio::spawn(async move { - let mut buf = vec![0u8; BUFFER_SIZE]; + // Get buffer from pool + let mut buf = pool_c2s.get(); let mut total_bytes = 0u64; + let mut prev_total_bytes = 0u64; let mut msg_count = 0u64; let mut last_activity = Instant::now(); let mut last_log = Instant::now(); loop { - // Read with timeout to prevent infinite hang on iOS + // Read with timeout let read_result = tokio::time::timeout( activity_timeout, client_reader.read(&mut buf) ).await; match read_result { - // Timeout - no activity for too long Err(_) => { warn!( user = %user_c2s, @@ -76,7 +75,6 @@ where break; } - // Read successful Ok(Ok(0)) => { debug!( user = %user_c2s, @@ -101,21 +99,26 @@ where user = %user_c2s, bytes = n, total = total_bytes, - data_preview = %hex::encode(&buf[..n.min(32)]), "C->S data" ); - // Log activity every 10 seconds for large transfers - if last_log.elapsed() > Duration::from_secs(10) { - let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); - info!( + // Log activity every 10 seconds with correct rate + let elapsed = last_log.elapsed(); + if elapsed > Duration::from_secs(10) { + let delta = total_bytes - prev_total_bytes; + let rate = delta as f64 / elapsed.as_secs_f64(); + + // Changed to DEBUG to reduce log spam + debug!( user = %user_c2s, total_bytes = total_bytes, msgs = msg_count, rate_kbps = (rate / 1024.0) as u64, "C->S transfer in progress" ); + last_log = Instant::now(); + prev_total_bytes = total_bytes; } if let Err(e) = server_writer.write_all(&buf[..n]).await { @@ -136,23 +139,23 @@ where } }); - // Server -> Client task with activity timeout + // Server -> Client task let s2c = tokio::spawn(async move { - let mut buf = vec![0u8; BUFFER_SIZE]; + // Get buffer from pool + let mut buf = pool_s2c.get(); let mut total_bytes = 0u64; + let mut prev_total_bytes = 0u64; let mut msg_count = 0u64; let mut last_activity = Instant::now(); let mut last_log = Instant::now(); loop { - // Read with timeout to prevent infinite hang on iOS let read_result = tokio::time::timeout( activity_timeout, server_reader.read(&mut buf) ).await; match read_result { - // Timeout - no activity for too long Err(_) => { warn!( user = %user_s2c, @@ -165,7 +168,6 @@ where break; } - // Read successful Ok(Ok(0)) => { debug!( user = %user_s2c, @@ -190,21 +192,25 @@ where user = %user_s2c, bytes = n, total = total_bytes, - data_preview = %hex::encode(&buf[..n.min(32)]), "S->C data" ); - // Log activity every 10 seconds for large transfers - if last_log.elapsed() > Duration::from_secs(10) { - let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); - info!( + let elapsed = last_log.elapsed(); + if elapsed > Duration::from_secs(10) { + let delta = total_bytes - prev_total_bytes; + let rate = delta as f64 / elapsed.as_secs_f64(); + + // Changed to DEBUG to reduce log spam + debug!( user = %user_s2c, total_bytes = total_bytes, msgs = msg_count, rate_kbps = (rate / 1024.0) as u64, "S->C transfer in progress" ); + last_log = Instant::now(); + prev_total_bytes = total_bytes; } if let Err(e) = client_writer.write_all(&buf[..n]).await { diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 6ae5af2..9fa495d 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -4,9 +4,11 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Instant; use dashmap::DashMap; -use parking_lot::RwLock; +use parking_lot::{RwLock, Mutex}; use lru::LruCache; use std::num::NonZeroUsize; +use std::hash::{Hash, Hasher}; +use std::collections::hash_map::DefaultHasher; /// Thread-safe statistics #[derive(Default)] @@ -141,37 +143,57 @@ impl Stats { } } -// Arc Hightech Stats :D - -/// Replay attack checker using LRU cache +/// Sharded Replay attack checker using LRU cache +/// Uses multiple independent LRU caches to reduce lock contention pub struct ReplayChecker { - handshakes: RwLock, ()>>, - tls_digests: RwLock, ()>>, + shards: Vec, ()>>>, + shard_mask: usize, } impl ReplayChecker { - pub fn new(capacity: usize) -> Self { - let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); + /// Create new replay checker with specified capacity per shard + /// Total capacity = capacity * num_shards + pub fn new(total_capacity: usize) -> Self { + // Use 64 shards for good concurrency + let num_shards = 64; + let shard_capacity = (total_capacity / num_shards).max(1); + let cap = NonZeroUsize::new(shard_capacity).unwrap(); + + let mut shards = Vec::with_capacity(num_shards); + for _ in 0..num_shards { + shards.push(Mutex::new(LruCache::new(cap))); + } + Self { - handshakes: RwLock::new(LruCache::new(cap)), - tls_digests: RwLock::new(LruCache::new(cap)), + shards, + shard_mask: num_shards - 1, } } + fn get_shard(&self, key: &[u8]) -> usize { + let mut hasher = DefaultHasher::new(); + key.hash(&mut hasher); + (hasher.finish() as usize) & self.shard_mask + } + pub fn check_handshake(&self, data: &[u8]) -> bool { - self.handshakes.read().contains(&data.to_vec()) + let shard_idx = self.get_shard(data); + self.shards[shard_idx].lock().contains(&data.to_vec()) } pub fn add_handshake(&self, data: &[u8]) { - self.handshakes.write().put(data.to_vec(), ()); + let shard_idx = self.get_shard(data); + self.shards[shard_idx].lock().put(data.to_vec(), ()); } pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.tls_digests.read().contains(&data.to_vec()) + let shard_idx = self.get_shard(data); + self.shards[shard_idx].lock().contains(&data.to_vec()) } pub fn add_tls_digest(&self, data: &[u8]) { - self.tls_digests.write().put(data.to_vec(), ()); + let shard_idx = self.get_shard(data); + self.shards[shard_idx].lock().put(data.to_vec(), ()); } } @@ -183,7 +205,6 @@ mod tests { fn test_stats_shared_counters() { let stats = Arc::new(Stats::new()); - // Симулируем использование из разных "задач" let stats1 = Arc::clone(&stats); let stats2 = Arc::clone(&stats); @@ -191,33 +212,20 @@ mod tests { stats2.increment_connects_all(); stats1.increment_connects_all(); - // Все инкременты должны быть видны assert_eq!(stats.get_connects_all(), 3); } #[test] - fn test_user_stats_shared() { - let stats = Arc::new(Stats::new()); + fn test_replay_checker_sharding() { + let checker = ReplayChecker::new(100); + let data1 = b"test1"; + let data2 = b"test2"; - let stats1 = Arc::clone(&stats); - let stats2 = Arc::clone(&stats); + checker.add_handshake(data1); + assert!(checker.check_handshake(data1)); + assert!(!checker.check_handshake(data2)); - stats1.add_user_octets_from("user1", 100); - stats2.add_user_octets_from("user1", 200); - stats1.add_user_octets_to("user1", 50); - - assert_eq!(stats.get_user_total_octets("user1"), 350); - } - - #[test] - fn test_concurrent_user_connects() { - let stats = Arc::new(Stats::new()); - - stats.increment_user_curr_connects("user1"); - stats.increment_user_curr_connects("user1"); - assert_eq!(stats.get_user_curr_connects("user1"), 2); - - stats.decrement_user_curr_connects("user1"); - assert_eq!(stats.get_user_curr_connects("user1"), 1); + checker.add_handshake(data2); + assert!(checker.check_handshake(data2)); } } \ No newline at end of file diff --git a/src/stream/crypto_stream.rs b/src/stream/crypto_stream.rs index b06bc7c..4705fe6 100644 --- a/src/stream/crypto_stream.rs +++ b/src/stream/crypto_stream.rs @@ -45,6 +45,11 @@ //! - when upstream is Pending but pending still has room: accept `to_accept` bytes and //! encrypt+append ciphertext directly into pending (in-place encryption of appended range) +//! Encrypted stream wrappers using AES-CTR +//! +//! This module provides stateful async stream wrappers that handle +//! encryption/decryption with proper partial read/write handling. + use bytes::{Bytes, BytesMut}; use std::io::{self, ErrorKind, Result}; use std::pin::Pin; @@ -58,8 +63,9 @@ use super::state::{StreamState, YieldBuffer}; // ============= Constants ============= /// Maximum size for pending ciphertext buffer (bounded backpressure). -/// 512 KiB tends to work well for mobile networks and avoids huge latency spikes. -const MAX_PENDING_WRITE: usize = 524_288; +/// Reduced to 64KB to prevent bufferbloat on mobile networks. +/// 512KB was causing high latency on 3G/LTE connections. +const MAX_PENDING_WRITE: usize = 64 * 1024; /// Default read buffer capacity (reader mostly decrypts in-place into caller buffer). const DEFAULT_READ_CAPACITY: usize = 16 * 1024; @@ -99,22 +105,6 @@ impl StreamState for CryptoReaderState { // ============= CryptoReader ============= /// Reader that decrypts data using AES-CTR with proper state machine. -/// -/// This reader handles partial reads correctly by maintaining internal state -/// and never losing any data that has been read from upstream. -/// -/// # State Machine -/// -/// ┌──────────┐ read ┌──────────┐ -/// │ Idle │ ------------> │ Yielding │ -/// │ │ <------------ │ │ -/// └──────────┘ drained └──────────┘ -/// │ │ -/// │ errors │ -/// ▼ ▼ -/// ┌──────────────────────────────────────┐ -/// │ Poisoned │ -/// └──────────────────────────────────────┘ pub struct CryptoReader { upstream: R, decryptor: AesCtr, @@ -315,10 +305,6 @@ impl CryptoReader { // ============= Pending Ciphertext ============= /// Pending ciphertext buffer with explicit position and strict max size. -/// -/// - append plaintext then encrypt appended range in-place - one-touch copy, no extra Vec -/// - move ciphertext from scratch into pending without copying -/// - explicit compaction behavior for long-lived connections #[derive(Debug)] struct PendingCiphertext { buf: BytesMut, @@ -361,15 +347,13 @@ impl PendingCiphertext { } // Compact when a large prefix was consumed. - if self.pos >= 32 * 1024 { + if self.pos >= 16 * 1024 { let _ = self.buf.split_to(self.pos); self.pos = 0; } } /// Replace the entire pending ciphertext by moving `src` in (swap, no copy). - /// - /// Precondition: src.len() <= max_len. fn replace_with(&mut self, mut src: BytesMut) { debug_assert!(src.len() <= self.max_len); @@ -381,12 +365,6 @@ impl PendingCiphertext { } /// Append plaintext and encrypt appended range in-place. - /// - /// This is the high-throughput buffering path: - /// - copy plaintext into pending buffer - /// - encrypt only the newly appended bytes - /// - /// CTR state advances by exactly plaintext.len(). fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> { if plaintext.is_empty() { return Ok(()); @@ -444,21 +422,10 @@ impl StreamState for CryptoWriterState { // ============= CryptoWriter ============= /// Writer that encrypts data using AES-CTR with correct async semantics. -/// -/// - CTR state advances exactly by the number of bytes we report as written -/// - If upstream blocks, ciphertext is buffered/bounded -/// - Backpressure is applied when buffer is full pub struct CryptoWriter { upstream: W, encryptor: AesCtr, state: CryptoWriterState, - - /// Scratch ciphertext for fast "write-through" path. - /// - /// Flow: - /// - encrypt plaintext into scratch - /// - try upstream write - /// - if Pending/partial: move remainder into pending without re-encrypting scratch: BytesMut, } @@ -531,9 +498,6 @@ impl CryptoWriter { } /// Select how many plaintext bytes can be accepted in buffering path - /// - /// Requirement: worst case - upstream pending, must buffer all ciphertext - /// for the accepted bytes fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize { if buf_len == 0 { return 0; @@ -557,11 +521,6 @@ impl CryptoWriter { impl CryptoWriter { /// Flush as much pending ciphertext as possible - /// - /// Returns - /// - Ready(Ok(())) if all pending is flushed or was none - /// - Pending if upstream would block - /// - Ready(Err(_)) on error fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { @@ -606,14 +565,6 @@ impl CryptoWriter { Poll::Ready(Ok(n)) => { pending.advance(n); - - trace!( - flushed = n, - pending_left = pending.pending_len(), - "CryptoWriter: flushed pending ciphertext" - ); - - // continue loop to flush more continue; } } @@ -643,9 +594,6 @@ impl AsyncWrite for CryptoWriter { } // 1) If we have pending ciphertext, prioritize flushing it - // If upstream pending - // -> still accept some plaintext ONLY if we can buffer - // all ciphertext for the accepted portion - bounded if matches!(this.state, CryptoWriterState::Flushing { .. }) { match this.poll_flush_pending(cx) { Poll::Ready(Ok(())) => { @@ -654,8 +602,6 @@ impl AsyncWrite for CryptoWriter { Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => { // Upstream blocked. Apply ideal backpressure - // - accept up to remaining pending capacity - // - if no capacity -> pending let to_accept = Self::select_to_accept_for_buffering(&this.state, buf.len()); @@ -670,11 +616,10 @@ impl AsyncWrite for CryptoWriter { let plaintext = &buf[..to_accept]; - // Disjoint borrows: borrow encryptor and state separately via a match + // Disjoint borrows let encryptor = &mut this.encryptor; let pending = Self::ensure_pending(&mut this.state); - // Should not WouldBlock because to_accept <= remaining_capacity if let Err(e) = pending.push_encrypted(encryptor, plaintext) { if e.kind() == ErrorKind::WouldBlock { return Poll::Pending; @@ -682,13 +627,6 @@ impl AsyncWrite for CryptoWriter { return Poll::Ready(Err(e)); } - trace!( - accepted = to_accept, - pending_len = pending.pending_len(), - pending_cap = pending.remaining_capacity(), - "CryptoWriter: upstream Pending, buffered ciphertext (accepted plaintext)" - ); - return Poll::Ready(Ok(to_accept)); } } @@ -697,9 +635,6 @@ impl AsyncWrite for CryptoWriter { // 2) Fast path: pending empty -> write-through debug_assert!(matches!(this.state, CryptoWriterState::Idle)); - // Worst-case buffering requirement - // - If upstream becomes pending -> buffer full ciphertext for accepted bytes - // -> accept at most MAX_PENDING_WRITE per poll_write call let to_accept = buf.len().min(MAX_PENDING_WRITE); let plaintext = &buf[..to_accept]; @@ -708,18 +643,11 @@ impl AsyncWrite for CryptoWriter { match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) { Poll::Pending => { // Upstream blocked: buffer FULL ciphertext for accepted bytes. - // Move scratch into pending without copying. let ciphertext = std::mem::take(&mut this.scratch); let pending = Self::ensure_pending(&mut this.state); pending.replace_with(ciphertext); - trace!( - accepted = to_accept, - pending_len = pending.pending_len(), - "CryptoWriter: write-through got Pending, buffered full ciphertext" - ); - Poll::Ready(Ok(to_accept)) } @@ -736,26 +664,11 @@ impl AsyncWrite for CryptoWriter { Poll::Ready(Ok(n)) => { if n == this.scratch.len() { - trace!( - accepted = to_accept, - ciphertext_len = this.scratch.len(), - "CryptoWriter: write-through wrote full ciphertext directly" - ); this.scratch.clear(); return Poll::Ready(Ok(to_accept)); } - // Partial upstream write of ciphertext: - // We accepted `to_accept` plaintext bytes, CTR already advanced for to_accept - // Must buffer the remainder ciphertext - warn!( - accepted = to_accept, - ciphertext_len = this.scratch.len(), - written_ciphertext = n, - "CryptoWriter: partial upstream write, buffering remainder" - ); - - // Split off remainder without copying + // Partial upstream write of ciphertext let remainder = this.scratch.split_off(n); this.scratch.clear(); @@ -788,7 +701,6 @@ impl AsyncWrite for CryptoWriter { let this = self.get_mut(); // Best-effort flush pending ciphertext before shutdown - // If upstream blocks, proceed to shutdown anyway match this.poll_flush_pending(cx) { Poll::Pending => { debug!( @@ -807,9 +719,6 @@ impl AsyncWrite for CryptoWriter { // ============= PassthroughStream ============= /// Passthrough stream for fast mode - no encryption/decryption -/// -/// Used when keys are set up so that client and Telegram use the same -/// encryption, allowing data to pass through without re-encryption pub struct PassthroughStream { inner: S, }