diff --git a/src/config/defaults.rs b/src/config/defaults.rs new file mode 100644 index 0000000..9df2813 --- /dev/null +++ b/src/config/defaults.rs @@ -0,0 +1,98 @@ +use std::net::IpAddr; +use std::collections::HashMap; +use serde::Deserialize; + +// Helper defaults kept private to the config module. +pub(crate) fn default_true() -> bool { + true +} + +pub(crate) fn default_port() -> u16 { + 443 +} + +pub(crate) fn default_tls_domain() -> String { + "www.google.com".to_string() +} + +pub(crate) fn default_mask_port() -> u16 { + 443 +} + +pub(crate) fn default_fake_cert_len() -> usize { + 2048 +} + +pub(crate) fn default_replay_check_len() -> usize { + 65_536 +} + +pub(crate) fn default_replay_window_secs() -> u64 { + 1800 +} + +pub(crate) fn default_handshake_timeout() -> u64 { + 15 +} + +pub(crate) fn default_connect_timeout() -> u64 { + 10 +} + +pub(crate) fn default_keepalive() -> u64 { + 60 +} + +pub(crate) fn default_ack_timeout() -> u64 { + 300 +} + +pub(crate) fn default_listen_addr() -> String { + "0.0.0.0".to_string() +} + +pub(crate) fn default_weight() -> u16 { + 1 +} + +pub(crate) fn default_metrics_whitelist() -> Vec { + vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()] +} + +pub(crate) fn default_prefer_4() -> u8 { + 4 +} + +pub(crate) fn default_unknown_dc_log_path() -> Option { + Some("unknown-dc.txt".to_string()) +} + +// Custom deserializer helpers + +#[derive(Deserialize)] +#[serde(untagged)] +pub(crate) enum OneOrMany { + One(String), + Many(Vec), +} + +pub(crate) fn deserialize_dc_overrides<'de, D>( + deserializer: D, +) -> std::result::Result>, D::Error> +where + D: serde::de::Deserializer<'de>, +{ + let raw: HashMap = HashMap::deserialize(deserializer)?; + let mut out = HashMap::new(); + for (dc, val) in raw { + let mut addrs = match val { + OneOrMany::One(s) => vec![s], + OneOrMany::Many(v) => v, + }; + addrs.retain(|s| !s.trim().is_empty()); + if !addrs.is_empty() { + out.insert(dc, addrs); + } + } + Ok(out) +} diff --git a/src/config/load.rs b/src/config/load.rs new file mode 100644 index 0000000..512b734 --- /dev/null +++ b/src/config/load.rs @@ -0,0 +1,295 @@ +use std::collections::HashMap; +use std::net::IpAddr; +use std::path::Path; + +use rand::Rng; +use tracing::warn; +use serde::{Serialize, Deserialize}; + +use crate::error::{ProxyError, Result}; + +use super::defaults::*; +use super::types::*; + +fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> { + if !net.ipv4 && matches!(net.ipv6, Some(false)) { + return Err(ProxyError::Config( + "Both ipv4 and ipv6 are disabled in [network]".to_string(), + )); + } + + if net.prefer != 4 && net.prefer != 6 { + return Err(ProxyError::Config( + "network.prefer must be 4 or 6".to_string(), + )); + } + + if !net.ipv4 && net.prefer == 4 { + warn!("prefer=4 but ipv4=false; forcing prefer=6"); + net.prefer = 6; + } + + if matches!(net.ipv6, Some(false)) && net.prefer == 6 { + warn!("prefer=6 but ipv6=false; forcing prefer=4"); + net.prefer = 4; + } + + Ok(()) +} + +// ============= Main Config ============= + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ProxyConfig { + #[serde(default)] + pub general: GeneralConfig, + + #[serde(default)] + pub network: NetworkConfig, + + #[serde(default)] + pub server: ServerConfig, + + #[serde(default)] + pub timeouts: TimeoutsConfig, + + #[serde(default)] + pub censorship: AntiCensorshipConfig, + + #[serde(default)] + pub access: AccessConfig, + + #[serde(default)] + pub upstreams: Vec, + + #[serde(default)] + pub show_link: ShowLink, + + /// DC address overrides for non-standard DCs (CDN, media, test, etc.) + /// Keys are DC indices as strings, values are one or more "ip:port" addresses. + /// Matches the C implementation's `proxy_for :` config directive. + /// Example in config.toml: + /// [dc_overrides] + /// "203" = ["149.154.175.100:443", "91.105.192.100:443"] + #[serde(default, deserialize_with = "deserialize_dc_overrides")] + pub dc_overrides: HashMap>, + + /// Default DC index (1-5) for unmapped non-standard DCs. + /// Matches the C implementation's `default ` config directive. + /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf). + #[serde(default)] + pub default_dc: Option, +} + +impl ProxyConfig { + pub fn load>(path: P) -> Result { + let content = + std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; + + let mut config: ProxyConfig = + toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; + + // Validate secrets. + for (user, secret) in &config.access.users { + if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { + return Err(ProxyError::InvalidSecret { + user: user.clone(), + reason: "Must be 32 hex characters".to_string(), + }); + } + } + + // Validate tls_domain. + if config.censorship.tls_domain.is_empty() { + return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); + } + + // Validate mask_unix_sock. + if let Some(ref sock_path) = config.censorship.mask_unix_sock { + if sock_path.is_empty() { + return Err(ProxyError::Config( + "mask_unix_sock cannot be empty".to_string(), + )); + } + #[cfg(unix)] + if sock_path.len() > 107 { + return Err(ProxyError::Config(format!( + "mask_unix_sock path too long: {} bytes (max 107)", + sock_path.len() + ))); + } + #[cfg(not(unix))] + return Err(ProxyError::Config( + "mask_unix_sock is only supported on Unix platforms".to_string(), + )); + + if config.censorship.mask_host.is_some() { + return Err(ProxyError::Config( + "mask_unix_sock and mask_host are mutually exclusive".to_string(), + )); + } + } + + // Default mask_host to tls_domain if not set and no unix socket configured. + if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() { + config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); + } + + // Migration: prefer_ipv6 -> network.prefer. + if config.general.prefer_ipv6 { + if config.network.prefer == 4 { + config.network.prefer = 6; + } + warn!("prefer_ipv6 is deprecated, use [network].prefer = 6"); + } + + // Auto-enable NAT probe when Middle Proxy is requested. + if config.general.use_middle_proxy && !config.general.middle_proxy_nat_probe { + config.general.middle_proxy_nat_probe = true; + warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode"); + } + + validate_network_cfg(&mut config.network)?; + + // Random fake_cert_len. + config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); + + // Resolve listen_tcp: explicit value wins, otherwise auto-detect. + // If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided. + // If no unix socket → TCP always (backward compat). + let listen_tcp = config.server.listen_tcp.unwrap_or_else(|| { + if config.server.listen_unix_sock.is_some() { + // Unix socket present: TCP only if user explicitly set addresses or listeners. + config.server.listen_addr_ipv4.is_some() + || !config.server.listeners.is_empty() + } else { + true + } + }); + + // Migration: Populate listeners if empty (skip when listen_tcp = false). + if config.server.listeners.is_empty() && listen_tcp { + let ipv4_str = config.server.listen_addr_ipv4 + .as_deref() + .unwrap_or("0.0.0.0"); + if let Ok(ipv4) = ipv4_str.parse::() { + config.server.listeners.push(ListenerConfig { + ip: ipv4, + announce: None, + announce_ip: None, + }); + } + if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { + if let Ok(ipv6) = ipv6_str.parse::() { + config.server.listeners.push(ListenerConfig { + ip: ipv6, + announce: None, + announce_ip: None, + }); + } + } + } + + // Migration: announce_ip → announce for each listener. + for listener in &mut config.server.listeners { + if listener.announce.is_none() && listener.announce_ip.is_some() { + listener.announce = Some(listener.announce_ip.unwrap().to_string()); + } + } + + // Migration: show_link (top-level) → general.links.show. + if !config.show_link.is_empty() && config.general.links.show.is_empty() { + config.general.links.show = config.show_link.clone(); + } + + // Migration: Populate upstreams if empty (Default Direct). + if config.upstreams.is_empty() { + config.upstreams.push(UpstreamConfig { + upstream_type: UpstreamType::Direct { interface: None }, + weight: 1, + enabled: true, + }); + } + + // Ensure default DC203 override is present. + config + .dc_overrides + .entry("203".to_string()) + .or_insert_with(|| vec!["91.105.192.100:443".to_string()]); + + Ok(config) + } + + pub fn validate(&self) -> Result<()> { + if self.access.users.is_empty() { + return Err(ProxyError::Config("No users configured".to_string())); + } + + if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls { + return Err(ProxyError::Config("No modes enabled".to_string())); + } + + if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { + return Err(ProxyError::Config(format!( + "Invalid tls_domain: '{}'. Must be a valid domain name", + self.censorship.tls_domain + ))); + } + + if let Some(tag) = &self.general.ad_tag { + let zeros = "00000000000000000000000000000000"; + if tag == zeros { + warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); + } + if tag.len() != 32 || tag.chars().any(|c| !c.is_ascii_hexdigit()) { + warn!("ad_tag is not a 32-char hex string; ensure you use value issued by @MTProxybot"); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dc_overrides_allow_string_and_array() { + let toml = r#" + [dc_overrides] + "201" = "149.154.175.50:443" + "202" = ["149.154.167.51:443", "149.154.175.100:443"] + "#; + let cfg: ProxyConfig = toml::from_str(toml).unwrap(); + assert_eq!(cfg.dc_overrides["201"], vec!["149.154.175.50:443"]); + assert_eq!( + cfg.dc_overrides["202"], + vec!["149.154.167.51:443", "149.154.175.100:443"] + ); + } + + #[test] + fn dc_overrides_inject_dc203_default() { + let toml = r#" + [general] + use_middle_proxy = false + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_dc_override_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert!(cfg + .dc_overrides + .get("203") + .map(|v| v.contains(&"91.105.192.100:443".to_string())) + .unwrap_or(false)); + let _ = std::fs::remove_file(path); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index a2a3120..a82d92b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,869 +1,8 @@ -//! Configuration +//! Configuration. -use crate::error::{ProxyError, Result}; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use serde::de::Deserializer; -use std::collections::HashMap; -use std::net::IpAddr; -use std::path::Path; -use tracing::warn; +pub(crate) mod defaults; +mod types; +mod load; -// ============= Helper Defaults ============= - -fn default_true() -> bool { - true -} -fn default_port() -> u16 { - 443 -} -fn default_tls_domain() -> String { - "www.google.com".to_string() -} -fn default_mask_port() -> u16 { - 443 -} -fn default_replay_check_len() -> usize { - 65536 -} -fn default_replay_window_secs() -> u64 { - 1800 -} -fn default_handshake_timeout() -> u64 { - 15 -} -fn default_connect_timeout() -> u64 { - 10 -} -fn default_keepalive() -> u64 { - 60 -} -fn default_ack_timeout() -> u64 { - 300 -} -fn default_listen_addr() -> String { - "0.0.0.0".to_string() -} -fn default_fake_cert_len() -> usize { - 2048 -} -fn default_weight() -> u16 { - 1 -} -fn default_metrics_whitelist() -> Vec { - vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()] -} - -fn default_prefer_4() -> u8 { - 4 -} - -fn default_unknown_dc_log_path() -> Option { - Some("unknown-dc.txt".to_string()) -} - -// ============= Custom Deserializers ============= - -#[derive(Deserialize)] -#[serde(untagged)] -enum OneOrMany { - One(String), - Many(Vec), -} - -fn deserialize_dc_overrides<'de, D>( - deserializer: D, -) -> std::result::Result>, D::Error> -where - D: Deserializer<'de>, -{ - let raw: HashMap = HashMap::deserialize(deserializer)?; - let mut out = HashMap::new(); - for (dc, val) in raw { - let mut addrs = match val { - OneOrMany::One(s) => vec![s], - OneOrMany::Many(v) => v, - }; - addrs.retain(|s| !s.trim().is_empty()); - if !addrs.is_empty() { - out.insert(dc, addrs); - } - } - Ok(out) -} - -// ============= Log Level ============= - -/// Logging verbosity level -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] -#[serde(rename_all = "lowercase")] -pub enum LogLevel { - /// All messages including trace (trace + debug + info + warn + error) - Debug, - /// Detailed operational logs (debug + info + warn + error) - Verbose, - /// Standard operational logs (info + warn + error) - #[default] - Normal, - /// Minimal output: only warnings and errors (warn + error). - /// Startup messages (config, DC connectivity, proxy links) are always shown - /// via info! before the filter is applied. - Silent, -} - -impl LogLevel { - /// Convert to tracing EnvFilter directive string - pub fn to_filter_str(&self) -> &'static str { - match self { - LogLevel::Debug => "trace", - LogLevel::Verbose => "debug", - LogLevel::Normal => "info", - LogLevel::Silent => "warn", - } - } - - /// Parse from a loose string (CLI argument) - pub fn from_str_loose(s: &str) -> Self { - match s.to_lowercase().as_str() { - "debug" | "trace" => LogLevel::Debug, - "verbose" => LogLevel::Verbose, - "normal" | "info" => LogLevel::Normal, - "silent" | "quiet" | "error" | "warn" => LogLevel::Silent, - _ => LogLevel::Normal, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn dc_overrides_allow_string_and_array() { - let toml = r#" - [dc_overrides] - "201" = "149.154.175.50:443" - "202" = ["149.154.167.51:443", "149.154.175.100:443"] - "#; - let cfg: ProxyConfig = toml::from_str(toml).unwrap(); - assert_eq!(cfg.dc_overrides["201"], vec!["149.154.175.50:443"]); - assert_eq!( - cfg.dc_overrides["202"], - vec!["149.154.167.51:443", "149.154.175.100:443"] - ); - } - - #[test] - fn dc_overrides_inject_dc203_default() { - let toml = r#" - [general] - use_middle_proxy = false - - [censorship] - tls_domain = "example.com" - - [access.users] - user = "00000000000000000000000000000000" - "#; - let dir = std::env::temp_dir(); - let path = dir.join("telemt_dc_override_test.toml"); - std::fs::write(&path, toml).unwrap(); - let cfg = ProxyConfig::load(&path).unwrap(); - assert!(cfg - .dc_overrides - .get("203") - .map(|v| v.contains(&"91.105.192.100:443".to_string())) - .unwrap_or(false)); - let _ = std::fs::remove_file(path); - } -} - -impl std::fmt::Display for LogLevel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - LogLevel::Debug => write!(f, "debug"), - LogLevel::Verbose => write!(f, "verbose"), - LogLevel::Normal => write!(f, "normal"), - LogLevel::Silent => write!(f, "silent"), - } - } -} - -fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> { - if !net.ipv4 && matches!(net.ipv6, Some(false)) { - return Err(ProxyError::Config( - "Both ipv4 and ipv6 are disabled in [network]".to_string(), - )); - } - - if net.prefer != 4 && net.prefer != 6 { - return Err(ProxyError::Config( - "network.prefer must be 4 or 6".to_string(), - )); - } - - if !net.ipv4 && net.prefer == 4 { - warn!("prefer=4 but ipv4=false; forcing prefer=6"); - net.prefer = 6; - } - - if matches!(net.ipv6, Some(false)) && net.prefer == 6 { - warn!("prefer=6 but ipv6=false; forcing prefer=4"); - net.prefer = 4; - } - - Ok(()) -} - -// ============= Sub-Configs ============= - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ProxyModes { - #[serde(default)] - pub classic: bool, - #[serde(default)] - pub secure: bool, - #[serde(default = "default_true")] - pub tls: bool, -} - -impl Default for ProxyModes { - fn default() -> Self { - Self { - classic: true, - secure: true, - tls: true, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct NetworkConfig { - #[serde(default = "default_true")] - pub ipv4: bool, - - /// None = auto-detect IPv6 availability - #[serde(default)] - pub ipv6: Option, - - /// 4 or 6 - #[serde(default = "default_prefer_4")] - pub prefer: u8, - - #[serde(default)] - pub multipath: bool, -} - -impl Default for NetworkConfig { - fn default() -> Self { - Self { - ipv4: true, - ipv6: None, - prefer: 4, - multipath: false, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeneralConfig { - #[serde(default)] - pub modes: ProxyModes, - - #[serde(default)] - pub prefer_ipv6: bool, - - #[serde(default = "default_true")] - pub fast_mode: bool, - - #[serde(default)] - pub use_middle_proxy: bool, - - #[serde(default)] - pub ad_tag: Option, - - /// Path to proxy-secret binary file (auto-downloaded if absent). - /// Infrastructure secret from https://core.telegram.org/getProxySecret - #[serde(default)] - pub proxy_secret_path: Option, - - /// Public IP override for middle-proxy NAT environments. - /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". - #[serde(default)] - pub middle_proxy_nat_ip: Option, - - /// Enable STUN-based NAT probing to discover public IP:port for ME KDF. - #[serde(default)] - pub middle_proxy_nat_probe: bool, - - /// Optional STUN server address (host:port) for NAT probing. - #[serde(default)] - pub middle_proxy_nat_stun: Option, - - /// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected). - #[serde(default)] - pub stun_iface_mismatch_ignore: bool, - - /// Log unknown (non-standard) DC requests to a file (default: unknown-dc.txt). Set to null to disable. - #[serde(default = "default_unknown_dc_log_path")] - pub unknown_dc_log_path: Option, - - #[serde(default)] - pub log_level: LogLevel, - - /// Disable colored output in logs (useful for files/systemd) - #[serde(default)] - pub disable_colors: bool, - - /// [general.links] — proxy link generation overrides - #[serde(default)] - pub links: LinksConfig, -} - -/// `[general.links]` — proxy link generation settings. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct LinksConfig { - /// List of usernames whose tg:// links to display at startup. - /// `"*"` = all users, `["alice", "bob"]` = specific users. - #[serde(default)] - pub show: ShowLink, - - /// Public hostname/IP for tg:// link generation (overrides detected IP). - #[serde(default)] - pub public_host: Option, - - /// Public port for tg:// link generation (overrides server.port). - #[serde(default)] - pub public_port: Option, -} - -impl Default for GeneralConfig { - fn default() -> Self { - Self { - modes: ProxyModes::default(), - prefer_ipv6: false, - fast_mode: true, - use_middle_proxy: false, - ad_tag: None, - proxy_secret_path: None, - middle_proxy_nat_ip: None, - middle_proxy_nat_probe: false, - middle_proxy_nat_stun: None, - stun_iface_mismatch_ignore: false, - unknown_dc_log_path: default_unknown_dc_log_path(), - log_level: LogLevel::Normal, - disable_colors: false, - links: LinksConfig::default(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerConfig { - #[serde(default = "default_port")] - pub port: u16, - - #[serde(default)] - pub listen_addr_ipv4: Option, - - #[serde(default)] - pub listen_addr_ipv6: Option, - - #[serde(default)] - pub listen_unix_sock: Option, - - /// Unix socket file permissions (octal, e.g. "0666" or "0777"). - /// Applied via chmod after bind. Default: no change (inherits umask). - #[serde(default)] - pub listen_unix_sock_perm: Option, - - /// Enable TCP listening. Default: true when no unix socket, false when - /// listen_unix_sock is set. Set explicitly to override auto-detection. - #[serde(default)] - pub listen_tcp: Option, - - #[serde(default)] - pub metrics_port: Option, - - #[serde(default = "default_metrics_whitelist")] - pub metrics_whitelist: Vec, - - #[serde(default)] - pub listeners: Vec, -} - -impl Default for ServerConfig { - fn default() -> Self { - Self { - port: default_port(), - listen_addr_ipv4: Some(default_listen_addr()), - listen_addr_ipv6: Some("::".to_string()), - listen_unix_sock: None, - listen_unix_sock_perm: None, - listen_tcp: None, - metrics_port: None, - metrics_whitelist: default_metrics_whitelist(), - listeners: Vec::new(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TimeoutsConfig { - #[serde(default = "default_handshake_timeout")] - pub client_handshake: u64, - - #[serde(default = "default_connect_timeout")] - pub tg_connect: u64, - - #[serde(default = "default_keepalive")] - pub client_keepalive: u64, - - #[serde(default = "default_ack_timeout")] - pub client_ack: u64, -} - -impl Default for TimeoutsConfig { - fn default() -> Self { - Self { - client_handshake: default_handshake_timeout(), - tg_connect: default_connect_timeout(), - client_keepalive: default_keepalive(), - client_ack: default_ack_timeout(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AntiCensorshipConfig { - #[serde(default = "default_tls_domain")] - pub tls_domain: String, - - #[serde(default = "default_true")] - pub mask: bool, - - #[serde(default)] - pub mask_host: Option, - - #[serde(default = "default_mask_port")] - pub mask_port: u16, - - #[serde(default)] - pub mask_unix_sock: Option, - - #[serde(default = "default_fake_cert_len")] - pub fake_cert_len: usize, -} - -impl Default for AntiCensorshipConfig { - fn default() -> Self { - Self { - tls_domain: default_tls_domain(), - mask: true, - mask_host: None, - mask_port: default_mask_port(), - mask_unix_sock: None, - fake_cert_len: default_fake_cert_len(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AccessConfig { - #[serde(default)] - pub users: HashMap, - - #[serde(default)] - pub user_max_tcp_conns: HashMap, - - #[serde(default)] - pub user_expirations: HashMap>, - - #[serde(default)] - pub user_data_quota: HashMap, - - #[serde(default)] - pub user_max_unique_ips: HashMap, - - #[serde(default = "default_replay_check_len")] - pub replay_check_len: usize, - - #[serde(default = "default_replay_window_secs")] - pub replay_window_secs: u64, - - #[serde(default)] - pub ignore_time_skew: bool, -} - -impl Default for AccessConfig { - fn default() -> Self { - let mut users = HashMap::new(); - users.insert( - "default".to_string(), - "00000000000000000000000000000000".to_string(), - ); - Self { - users, - user_max_tcp_conns: HashMap::new(), - user_expirations: HashMap::new(), - user_data_quota: HashMap::new(), - user_max_unique_ips: HashMap::new(), - replay_check_len: default_replay_check_len(), - replay_window_secs: default_replay_window_secs(), - ignore_time_skew: false, - } - } -} - -// ============= Aux Structures ============= - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(tag = "type", rename_all = "lowercase")] -pub enum UpstreamType { - Direct { - #[serde(default)] - interface: Option, - }, - Socks4 { - address: String, - #[serde(default)] - interface: Option, - #[serde(default)] - user_id: Option, - }, - Socks5 { - address: String, - #[serde(default)] - interface: Option, - #[serde(default)] - username: Option, - #[serde(default)] - password: Option, - }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpstreamConfig { - #[serde(flatten)] - pub upstream_type: UpstreamType, - #[serde(default = "default_weight")] - pub weight: u16, - #[serde(default = "default_true")] - pub enabled: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ListenerConfig { - pub ip: IpAddr, - /// IP address or hostname to announce in proxy links. - /// Takes precedence over `announce_ip` if both are set. - #[serde(default)] - pub announce: Option, - /// Deprecated: Use `announce` instead. IP address to announce in proxy links. - /// Migrated to `announce` automatically if `announce` is not set. - #[serde(default)] - pub announce_ip: Option, -} - -// ============= ShowLink ============= - -/// Controls which users' proxy links are displayed at startup. -/// -/// In TOML, this can be: -/// - `show_link = "*"` — show links for all users -/// - `show_link = ["a", "b"]` — show links for specific users -/// - omitted — show no links (default) -#[derive(Debug, Clone)] -pub enum ShowLink { - /// Don't show any links (default when omitted) - None, - /// Show links for all configured users - All, - /// Show links for specific users - Specific(Vec), -} - -impl Default for ShowLink { - fn default() -> Self { - ShowLink::None - } -} - -impl ShowLink { - /// Returns true if no links should be shown - pub fn is_empty(&self) -> bool { - matches!(self, ShowLink::None) || matches!(self, ShowLink::Specific(v) if v.is_empty()) - } - - /// Resolve the list of user names to display, given all configured users - pub fn resolve_users<'a>(&'a self, all_users: &'a HashMap) -> Vec<&'a String> { - match self { - ShowLink::None => vec![], - ShowLink::All => { - let mut names: Vec<&String> = all_users.keys().collect(); - names.sort(); - names - } - ShowLink::Specific(names) => names.iter().collect(), - } - } -} - -impl Serialize for ShowLink { - fn serialize(&self, serializer: S) -> std::result::Result { - match self { - ShowLink::None => Vec::::new().serialize(serializer), - ShowLink::All => serializer.serialize_str("*"), - ShowLink::Specific(v) => v.serialize(serializer), - } - } -} - -impl<'de> Deserialize<'de> for ShowLink { - fn deserialize>(deserializer: D) -> std::result::Result { - use serde::de; - - struct ShowLinkVisitor; - - impl<'de> de::Visitor<'de> for ShowLinkVisitor { - type Value = ShowLink; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str(r#""*" or an array of user names"#) - } - - fn visit_str(self, v: &str) -> std::result::Result { - if v == "*" { - Ok(ShowLink::All) - } else { - Err(de::Error::invalid_value( - de::Unexpected::Str(v), - &r#""*""#, - )) - } - } - - fn visit_seq>(self, mut seq: A) -> std::result::Result { - let mut names = Vec::new(); - while let Some(name) = seq.next_element::()? { - names.push(name); - } - if names.is_empty() { - Ok(ShowLink::None) - } else { - Ok(ShowLink::Specific(names)) - } - } - } - - deserializer.deserialize_any(ShowLinkVisitor) - } -} - -// ============= Main Config ============= - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct ProxyConfig { - #[serde(default)] - pub general: GeneralConfig, - - #[serde(default)] - pub network: NetworkConfig, - - #[serde(default)] - pub server: ServerConfig, - - #[serde(default)] - pub timeouts: TimeoutsConfig, - - #[serde(default)] - pub censorship: AntiCensorshipConfig, - - #[serde(default)] - pub access: AccessConfig, - - #[serde(default)] - pub upstreams: Vec, - - #[serde(default)] - pub show_link: ShowLink, - - /// DC address overrides for non-standard DCs (CDN, media, test, etc.) - /// Keys are DC indices as strings, values are one or more \"ip:port\" addresses. - /// Matches the C implementation's `proxy_for :` config directive. - /// Example in config.toml: - /// [dc_overrides] - /// \"203\" = [\"149.154.175.100:443\", \"91.105.192.100:443\"] - #[serde(default, deserialize_with = "deserialize_dc_overrides")] - pub dc_overrides: HashMap>, - - /// Default DC index (1-5) for unmapped non-standard DCs. - /// Matches the C implementation's `default ` config directive. - /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf). - #[serde(default)] - pub default_dc: Option, -} - -impl ProxyConfig { - pub fn load>(path: P) -> Result { - let content = - std::fs::read_to_string(path).map_err(|e| ProxyError::Config(e.to_string()))?; - - let mut config: ProxyConfig = - toml::from_str(&content).map_err(|e| ProxyError::Config(e.to_string()))?; - - // Validate secrets - for (user, secret) in &config.access.users { - if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { - return Err(ProxyError::InvalidSecret { - user: user.clone(), - reason: "Must be 32 hex characters".to_string(), - }); - } - } - - // Validate tls_domain - if config.censorship.tls_domain.is_empty() { - return Err(ProxyError::Config("tls_domain cannot be empty".to_string())); - } - - // Validate mask_unix_sock - if let Some(ref sock_path) = config.censorship.mask_unix_sock { - if sock_path.is_empty() { - return Err(ProxyError::Config( - "mask_unix_sock cannot be empty".to_string(), - )); - } - #[cfg(unix)] - if sock_path.len() > 107 { - return Err(ProxyError::Config(format!( - "mask_unix_sock path too long: {} bytes (max 107)", - sock_path.len() - ))); - } - #[cfg(not(unix))] - return Err(ProxyError::Config( - "mask_unix_sock is only supported on Unix platforms".to_string(), - )); - - if config.censorship.mask_host.is_some() { - return Err(ProxyError::Config( - "mask_unix_sock and mask_host are mutually exclusive".to_string(), - )); - } - } - - // Default mask_host to tls_domain if not set and no unix socket configured - if config.censorship.mask_host.is_none() && config.censorship.mask_unix_sock.is_none() { - config.censorship.mask_host = Some(config.censorship.tls_domain.clone()); - } - - // Migration: prefer_ipv6 -> network.prefer - if config.general.prefer_ipv6 { - if config.network.prefer == 4 { - config.network.prefer = 6; - } - warn!("prefer_ipv6 is deprecated, use [network].prefer = 6"); - } - - validate_network_cfg(&mut config.network)?; - - // Random fake_cert_len - use rand::Rng; - config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096); - - // Resolve listen_tcp: explicit value wins, otherwise auto-detect. - // If unix socket is set → TCP only when listen_addr_ipv4 or listeners are explicitly provided. - // If no unix socket → TCP always (backward compat). - let listen_tcp = config.server.listen_tcp.unwrap_or_else(|| { - if config.server.listen_unix_sock.is_some() { - // Unix socket present: TCP only if user explicitly set addresses or listeners - config.server.listen_addr_ipv4.is_some() - || !config.server.listeners.is_empty() - } else { - true - } - }); - - // Migration: Populate listeners if empty (skip when listen_tcp = false) - if config.server.listeners.is_empty() && listen_tcp { - let ipv4_str = config.server.listen_addr_ipv4 - .as_deref() - .unwrap_or("0.0.0.0"); - if let Ok(ipv4) = ipv4_str.parse::() { - config.server.listeners.push(ListenerConfig { - ip: ipv4, - announce: None, - announce_ip: None, - }); - } - if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { - if let Ok(ipv6) = ipv6_str.parse::() { - config.server.listeners.push(ListenerConfig { - ip: ipv6, - announce: None, - announce_ip: None, - }); - } - } - } - - // Migration: announce_ip → announce for each listener - for listener in &mut config.server.listeners { - if listener.announce.is_none() && listener.announce_ip.is_some() { - listener.announce = Some(listener.announce_ip.unwrap().to_string()); - } - } - - // Migration: show_link (top-level) → general.links.show - if !config.show_link.is_empty() && config.general.links.show.is_empty() { - config.general.links.show = config.show_link.clone(); - } - - // Migration: Populate upstreams if empty (Default Direct) - if config.upstreams.is_empty() { - config.upstreams.push(UpstreamConfig { - upstream_type: UpstreamType::Direct { interface: None }, - weight: 1, - enabled: true, - }); - } - - // Ensure default DC203 override is present. - config - .dc_overrides - .entry("203".to_string()) - .or_insert_with(|| vec!["91.105.192.100:443".to_string()]); - - Ok(config) - } - - pub fn validate(&self) -> Result<()> { - if self.access.users.is_empty() { - return Err(ProxyError::Config("No users configured".to_string())); - } - - if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls { - return Err(ProxyError::Config("No modes enabled".to_string())); - } - - if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') { - return Err(ProxyError::Config(format!( - "Invalid tls_domain: '{}'. Must be a valid domain name", - self.censorship.tls_domain - ))); - } - - if let Some(tag) = &self.general.ad_tag { - let zeros = "00000000000000000000000000000000"; - if tag == zeros { - warn!("ad_tag is all zeros; register a valid proxy tag via @MTProxybot to enable sponsored channel"); - } - if tag.len() != 32 || tag.chars().any(|c| !c.is_ascii_hexdigit()) { - warn!("ad_tag is not a 32-char hex string; ensure you use value issued by @MTProxybot"); - } - } - - Ok(()) - } -} +pub use load::ProxyConfig; +pub use types::*; diff --git a/src/config/types.rs b/src/config/types.rs new file mode 100644 index 0000000..c961808 --- /dev/null +++ b/src/config/types.rs @@ -0,0 +1,504 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::net::IpAddr; + +use super::defaults::*; + +// ============= Log Level ============= + +/// Logging verbosity level. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LogLevel { + /// All messages including trace (trace + debug + info + warn + error). + Debug, + /// Detailed operational logs (debug + info + warn + error). + Verbose, + /// Standard operational logs (info + warn + error). + #[default] + Normal, + /// Minimal output: only warnings and errors (warn + error). + /// Startup messages (config, DC connectivity, proxy links) are always shown + /// via info! before the filter is applied. + Silent, +} + +impl LogLevel { + /// Convert to tracing EnvFilter directive string. + pub fn to_filter_str(&self) -> &'static str { + match self { + LogLevel::Debug => "trace", + LogLevel::Verbose => "debug", + LogLevel::Normal => "info", + LogLevel::Silent => "warn", + } + } + + /// Parse from a loose string (CLI argument). + pub fn from_str_loose(s: &str) -> Self { + match s.to_lowercase().as_str() { + "debug" | "trace" => LogLevel::Debug, + "verbose" => LogLevel::Verbose, + "normal" | "info" => LogLevel::Normal, + "silent" | "quiet" | "error" | "warn" => LogLevel::Silent, + _ => LogLevel::Normal, + } + } +} + +impl std::fmt::Display for LogLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogLevel::Debug => write!(f, "debug"), + LogLevel::Verbose => write!(f, "verbose"), + LogLevel::Normal => write!(f, "normal"), + LogLevel::Silent => write!(f, "silent"), + } + } +} + +// ============= Sub-Configs ============= + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProxyModes { + #[serde(default)] + pub classic: bool, + #[serde(default)] + pub secure: bool, + #[serde(default = "default_true")] + pub tls: bool, +} + +impl Default for ProxyModes { + fn default() -> Self { + Self { + classic: true, + secure: true, + tls: true, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NetworkConfig { + #[serde(default = "default_true")] + pub ipv4: bool, + + /// None = auto-detect IPv6 availability. + #[serde(default)] + pub ipv6: Option, + + /// 4 or 6. + #[serde(default = "default_prefer_4")] + pub prefer: u8, + + #[serde(default)] + pub multipath: bool, +} + +impl Default for NetworkConfig { + fn default() -> Self { + Self { + ipv4: true, + ipv6: None, + prefer: 4, + multipath: false, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GeneralConfig { + #[serde(default)] + pub modes: ProxyModes, + + #[serde(default)] + pub prefer_ipv6: bool, + + #[serde(default = "default_true")] + pub fast_mode: bool, + + #[serde(default)] + pub use_middle_proxy: bool, + + #[serde(default)] + pub ad_tag: Option, + + /// Path to proxy-secret binary file (auto-downloaded if absent). + /// Infrastructure secret from https://core.telegram.org/getProxySecret. + #[serde(default)] + pub proxy_secret_path: Option, + + /// Public IP override for middle-proxy NAT environments. + /// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr". + #[serde(default)] + pub middle_proxy_nat_ip: Option, + + /// Enable STUN-based NAT probing to discover public IP:port for ME KDF. + #[serde(default)] + pub middle_proxy_nat_probe: bool, + + /// Optional STUN server address (host:port) for NAT probing. + #[serde(default)] + pub middle_proxy_nat_stun: Option, + + /// Ignore STUN/interface IP mismatch (keep using Middle Proxy even if NAT detected). + #[serde(default)] + pub stun_iface_mismatch_ignore: bool, + + /// Log unknown (non-standard) DC requests to a file (default: unknown-dc.txt). Set to null to disable. + #[serde(default = "default_unknown_dc_log_path")] + pub unknown_dc_log_path: Option, + + #[serde(default)] + pub log_level: LogLevel, + + /// Disable colored output in logs (useful for files/systemd). + #[serde(default)] + pub disable_colors: bool, + + /// [general.links] — proxy link generation overrides. + #[serde(default)] + pub links: LinksConfig, +} + +impl Default for GeneralConfig { + fn default() -> Self { + Self { + modes: ProxyModes::default(), + prefer_ipv6: false, + fast_mode: true, + use_middle_proxy: false, + ad_tag: None, + proxy_secret_path: None, + middle_proxy_nat_ip: None, + middle_proxy_nat_probe: false, + middle_proxy_nat_stun: None, + stun_iface_mismatch_ignore: false, + unknown_dc_log_path: default_unknown_dc_log_path(), + log_level: LogLevel::Normal, + disable_colors: false, + links: LinksConfig::default(), + } + } +} + +/// `[general.links]` — proxy link generation settings. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct LinksConfig { + /// List of usernames whose tg:// links to display at startup. + /// `"*"` = all users, `["alice", "bob"]` = specific users. + #[serde(default)] + pub show: ShowLink, + + /// Public hostname/IP for tg:// link generation (overrides detected IP). + #[serde(default)] + pub public_host: Option, + + /// Public port for tg:// link generation (overrides server.port). + #[serde(default)] + pub public_port: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + #[serde(default = "default_port")] + pub port: u16, + + #[serde(default)] + pub listen_addr_ipv4: Option, + + #[serde(default)] + pub listen_addr_ipv6: Option, + + #[serde(default)] + pub listen_unix_sock: Option, + + /// Unix socket file permissions (octal, e.g. "0666" or "0777"). + /// Applied via chmod after bind. Default: no change (inherits umask). + #[serde(default)] + pub listen_unix_sock_perm: Option, + + /// Enable TCP listening. Default: true when no unix socket, false when + /// listen_unix_sock is set. Set explicitly to override auto-detection. + #[serde(default)] + pub listen_tcp: Option, + + #[serde(default)] + pub metrics_port: Option, + + #[serde(default = "default_metrics_whitelist")] + pub metrics_whitelist: Vec, + + #[serde(default)] + pub listeners: Vec, +} + +impl Default for ServerConfig { + fn default() -> Self { + Self { + port: default_port(), + listen_addr_ipv4: Some(default_listen_addr()), + listen_addr_ipv6: Some("::".to_string()), + listen_unix_sock: None, + listen_unix_sock_perm: None, + listen_tcp: None, + metrics_port: None, + metrics_whitelist: default_metrics_whitelist(), + listeners: Vec::new(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TimeoutsConfig { + #[serde(default = "default_handshake_timeout")] + pub client_handshake: u64, + + #[serde(default = "default_connect_timeout")] + pub tg_connect: u64, + + #[serde(default = "default_keepalive")] + pub client_keepalive: u64, + + #[serde(default = "default_ack_timeout")] + pub client_ack: u64, +} + +impl Default for TimeoutsConfig { + fn default() -> Self { + Self { + client_handshake: default_handshake_timeout(), + tg_connect: default_connect_timeout(), + client_keepalive: default_keepalive(), + client_ack: default_ack_timeout(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AntiCensorshipConfig { + #[serde(default = "default_tls_domain")] + pub tls_domain: String, + + #[serde(default = "default_true")] + pub mask: bool, + + #[serde(default)] + pub mask_host: Option, + + #[serde(default = "default_mask_port")] + pub mask_port: u16, + + #[serde(default)] + pub mask_unix_sock: Option, + + #[serde(default = "default_fake_cert_len")] + pub fake_cert_len: usize, +} + +impl Default for AntiCensorshipConfig { + fn default() -> Self { + Self { + tls_domain: default_tls_domain(), + mask: true, + mask_host: None, + mask_port: default_mask_port(), + mask_unix_sock: None, + fake_cert_len: default_fake_cert_len(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AccessConfig { + #[serde(default)] + pub users: HashMap, + + #[serde(default)] + pub user_max_tcp_conns: HashMap, + + #[serde(default)] + pub user_expirations: HashMap>, + + #[serde(default)] + pub user_data_quota: HashMap, + + #[serde(default)] + pub user_max_unique_ips: HashMap, + + #[serde(default = "default_replay_check_len")] + pub replay_check_len: usize, + + #[serde(default = "default_replay_window_secs")] + pub replay_window_secs: u64, + + #[serde(default)] + pub ignore_time_skew: bool, +} + +impl Default for AccessConfig { + fn default() -> Self { + let mut users = HashMap::new(); + users.insert( + "default".to_string(), + "00000000000000000000000000000000".to_string(), + ); + Self { + users, + user_max_tcp_conns: HashMap::new(), + user_expirations: HashMap::new(), + user_data_quota: HashMap::new(), + user_max_unique_ips: HashMap::new(), + replay_check_len: default_replay_check_len(), + replay_window_secs: default_replay_window_secs(), + ignore_time_skew: false, + } + } +} + +// ============= Aux Structures ============= + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum UpstreamType { + Direct { + #[serde(default)] + interface: Option, + }, + Socks4 { + address: String, + #[serde(default)] + interface: Option, + #[serde(default)] + user_id: Option, + }, + Socks5 { + address: String, + #[serde(default)] + interface: Option, + #[serde(default)] + username: Option, + #[serde(default)] + password: Option, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpstreamConfig { + #[serde(flatten)] + pub upstream_type: UpstreamType, + #[serde(default = "default_weight")] + pub weight: u16, + #[serde(default = "default_true")] + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ListenerConfig { + pub ip: IpAddr, + /// IP address or hostname to announce in proxy links. + /// Takes precedence over `announce_ip` if both are set. + #[serde(default)] + pub announce: Option, + /// Deprecated: Use `announce` instead. IP address to announce in proxy links. + /// Migrated to `announce` automatically if `announce` is not set. + #[serde(default)] + pub announce_ip: Option, +} + +// ============= ShowLink ============= + +/// Controls which users' proxy links are displayed at startup. +/// +/// In TOML, this can be: +/// - `show_link = "*"` — show links for all users +/// - `show_link = ["a", "b"]` — show links for specific users +/// - omitted — show no links (default) +#[derive(Debug, Clone)] +pub enum ShowLink { + /// Don't show any links (default when omitted). + None, + /// Show links for all configured users. + All, + /// Show links for specific users. + Specific(Vec), +} + +impl Default for ShowLink { + fn default() -> Self { + ShowLink::None + } +} + +impl ShowLink { + /// Returns true if no links should be shown. + pub fn is_empty(&self) -> bool { + matches!(self, ShowLink::None) || matches!(self, ShowLink::Specific(v) if v.is_empty()) + } + + /// Resolve the list of user names to display, given all configured users. + pub fn resolve_users<'a>(&'a self, all_users: &'a HashMap) -> Vec<&'a String> { + match self { + ShowLink::None => vec![], + ShowLink::All => { + let mut names: Vec<&String> = all_users.keys().collect(); + names.sort(); + names + } + ShowLink::Specific(names) => names.iter().collect(), + } + } +} + +impl Serialize for ShowLink { + fn serialize(&self, serializer: S) -> std::result::Result { + match self { + ShowLink::None => Vec::::new().serialize(serializer), + ShowLink::All => serializer.serialize_str("*"), + ShowLink::Specific(v) => v.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for ShowLink { + fn deserialize>(deserializer: D) -> std::result::Result { + use serde::de; + + struct ShowLinkVisitor; + + impl<'de> de::Visitor<'de> for ShowLinkVisitor { + type Value = ShowLink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str(r#""*" or an array of user names"#) + } + + fn visit_str(self, v: &str) -> std::result::Result { + if v == "*" { + Ok(ShowLink::All) + } else { + Err(de::Error::invalid_value( + de::Unexpected::Str(v), + &r#""*""#, + )) + } + } + + fn visit_seq>(self, mut seq: A) -> std::result::Result { + let mut names = Vec::new(); + while let Some(name) = seq.next_element::()? { + names.push(name); + } + if names.is_empty() { + Ok(ShowLink::None) + } else { + Ok(ShowLink::Specific(names)) + } + } + } + + deserializer.deserialize_any(ShowLinkVisitor) + } +} diff --git a/src/main.rs b/src/main.rs index 9865558..57b993d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -326,6 +326,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai config.general.middle_proxy_nat_ip, config.general.middle_proxy_nat_probe, config.general.middle_proxy_nat_stun.clone(), + probe.detected_ipv6, cfg_v4.map.clone(), cfg_v6.map.clone(), cfg_v4.default_dc.or(cfg_v6.default_dc), diff --git a/src/network/probe.rs b/src/network/probe.rs index 2a220f5..d290ac1 100644 --- a/src/network/probe.rs +++ b/src/network/probe.rs @@ -58,7 +58,13 @@ pub async fn run_probe(config: &NetworkConfig, stun_addr: Option, nat_pr let stun_server = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string()); let stun_res = if nat_probe { - stun_probe_dual(&stun_server).await? + match stun_probe_dual(&stun_server).await { + Ok(res) => res, + Err(e) => { + warn!(error = %e, "STUN probe failed, continuing without reflection"); + DualStunResult::default() + } + } } else { DualStunResult::default() }; diff --git a/src/network/stun.rs b/src/network/stun.rs index e1c811d..251454e 100644 --- a/src/network/stun.rs +++ b/src/network/stun.rs @@ -1,6 +1,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use tokio::net::{lookup_host, UdpSocket}; +use tokio::time::{timeout, Duration, sleep}; use crate::error::{ProxyError, Result}; @@ -63,22 +64,36 @@ pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result n, + Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))), + Err(_) => { + attempt += 1; + if attempt >= 3 { + return Ok(None); + } + sleep(backoff).await; + backoff *= 2; + continue; + } + }; + + if n < 20 { + return Ok(None); + } + + let magic = 0x2112A442u32.to_be_bytes(); + let txid = &req[8..20]; let mut idx = 20; while idx + 4 <= n { let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap()); @@ -160,6 +175,8 @@ pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result= 16 { self.iv.copy_from_slice(&buf[buf.len() - 16..]); } - self.writer.write_all(&buf).await.map_err(ProxyError::Io)?; + self.writer.write_all(&buf).await.map_err(ProxyError::Io) + } + + pub(crate) async fn send_and_flush(&mut self, payload: &[u8]) -> Result<()> { + self.send(payload).await?; self.writer.flush().await.map_err(ProxyError::Io) } } diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 8ac6986..3c36820 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -3,7 +3,6 @@ use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; -use regex::Regex; use httpdate; use tracing::{debug, info, warn}; @@ -20,6 +19,45 @@ pub struct ProxyConfigData { pub default_dc: Option, } +fn parse_host_port(s: &str) -> Option<(IpAddr, u16)> { + if let Some(bracket_end) = s.rfind(']') { + if s.starts_with('[') && bracket_end + 1 < s.len() && s.as_bytes().get(bracket_end + 1) == Some(&b':') { + let host = &s[1..bracket_end]; + let port_str = &s[bracket_end + 2..]; + let ip = host.parse::().ok()?; + let port = port_str.parse::().ok()?; + return Some((ip, port)); + } + } + + let idx = s.rfind(':')?; + let host = &s[..idx]; + let port_str = &s[idx + 1..]; + let ip = host.parse::().ok()?; + let port = port_str.parse::().ok()?; + Some((ip, port)) +} + +fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> { + // Accepts lines like: + // proxy_for 4 91.108.4.195:8888; + // proxy_for 2 [2001:67c:04e8:f002::d]:80; + // proxy_for 2 2001:67c:04e8:f002::d:80; + let trimmed = line.trim(); + if !trimmed.starts_with("proxy_for") { + return None; + } + // Capture everything between dc and trailing ';' + let without_prefix = trimmed.trim_start_matches("proxy_for").trim(); + let mut parts = without_prefix.split_whitespace(); + let dc_str = parts.next()?; + let rest = parts.next()?; + let host_port = rest.trim_end_matches(';'); + let dc = dc_str.parse::().ok()?; + let (ip, port) = parse_host_port(host_port)?; + Some((dc, ip, port)) +} + pub async fn fetch_proxy_config(url: &str) -> Result { let resp = reqwest::get(url) .await @@ -48,26 +86,26 @@ pub async fn fetch_proxy_config(url: &str) -> Result { .await .map_err(|e| crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")))?; - let re_proxy = Regex::new(r"proxy_for\s+(-?\d+)\s+([^\s:]+):(\d+)\s*;").unwrap(); - let re_default = Regex::new(r"default\s+(-?\d+)\s*;").unwrap(); - let mut map: HashMap> = HashMap::new(); - for cap in re_proxy.captures_iter(&text) { - if let (Some(dc), Some(host), Some(port)) = (cap.get(1), cap.get(2), cap.get(3)) { - if let Ok(dc_idx) = dc.as_str().parse::() { - if let Ok(ip) = host.as_str().parse::() { - if let Ok(port_num) = port.as_str().parse::() { - map.entry(dc_idx).or_default().push((ip, port_num)); - } - } - } + for line in text.lines() { + if let Some((dc, ip, port)) = parse_proxy_line(line) { + map.entry(dc).or_default().push((ip, port)); } } - let default_dc = re_default - .captures(&text) - .and_then(|c| c.get(1)) - .and_then(|m| m.as_str().parse::().ok()); + let default_dc = text + .lines() + .find_map(|l| { + let t = l.trim(); + if let Some(rest) = t.strip_prefix("default") { + return rest + .trim() + .trim_end_matches(';') + .parse::() + .ok(); + } + None + }); Ok(ProxyConfigData { map, default_dc }) } @@ -111,3 +149,35 @@ pub async fn me_config_updater(pool: Arc, rng: Arc, interv } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_ipv6_bracketed() { + let line = "proxy_for 2 [2001:67c:04e8:f002::d]:80;"; + let res = parse_proxy_line(line).unwrap(); + assert_eq!(res.0, 2); + assert_eq!(res.1, "2001:67c:04e8:f002::d".parse::().unwrap()); + assert_eq!(res.2, 80); + } + + #[test] + fn parse_ipv6_plain() { + let line = "proxy_for 2 2001:67c:04e8:f002::d:80;"; + let res = parse_proxy_line(line).unwrap(); + assert_eq!(res.0, 2); + assert_eq!(res.1, "2001:67c:04e8:f002::d".parse::().unwrap()); + assert_eq!(res.2, 80); + } + + #[test] + fn parse_ipv4() { + let line = "proxy_for 4 91.108.4.195:8888;"; + let res = parse_proxy_line(line).unwrap(); + assert_eq!(res.0, 4); + assert_eq!(res.1, "91.108.4.195".parse::().unwrap()); + assert_eq!(res.2, 8888); + } +} diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index a860d01..1c08508 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -10,7 +10,7 @@ use std::os::raw::c_int; use bytes::BytesMut; use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::TcpStream; +use tokio::net::{TcpStream, TcpSocket}; use tokio::time::timeout; use tracing::{debug, info, warn}; @@ -44,7 +44,28 @@ impl MePool { /// TCP connect with timeout + return RTT in milliseconds. pub(crate) async fn connect_tcp(&self, addr: SocketAddr) -> Result<(TcpStream, f64)> { let start = Instant::now(); - let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), TcpStream::connect(addr)) + let connect_fut = async { + if addr.is_ipv6() { + if let Some(v6) = self.detected_ipv6 { + match TcpSocket::new_v6() { + Ok(sock) => { + if let Err(e) = sock.bind(SocketAddr::new(IpAddr::V6(v6), 0)) { + debug!(error = %e, bind_ip = %v6, "ME IPv6 bind failed, falling back to default bind"); + } else { + match sock.connect(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => debug!(error = %e, target = %addr, "ME IPv6 bound connect failed, retrying default connect"), + } + } + } + Err(e) => debug!(error = %e, "ME IPv6 socket creation failed, falling back to default connect"), + } + } + } + TcpStream::connect(addr).await + }; + + let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), connect_fut) .await .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??; let connect_ms = start.elapsed().as_secs_f64() * 1000.0; diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 9771b6b..8510dfd 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,5 +1,5 @@ use std::collections::HashMap; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering}; use bytes::BytesMut; @@ -32,6 +32,7 @@ pub struct MeWriter { pub writer: Arc>, pub cancel: CancellationToken, pub degraded: Arc, + pub draining: Arc, } pub struct MePool { @@ -46,6 +47,7 @@ pub struct MePool { pub(super) nat_ip_detected: Arc>>, pub(super) nat_probe: bool, pub(super) nat_stun: Option, + pub(super) detected_ipv6: Option, pub(super) proxy_map_v4: Arc>>>, pub(super) proxy_map_v6: Arc>>>, pub(super) default_dc: AtomicI32, @@ -69,6 +71,7 @@ impl MePool { nat_ip: Option, nat_probe: bool, nat_stun: Option, + detected_ipv6: Option, proxy_map_v4: HashMap>, proxy_map_v6: HashMap>, default_dc: Option, @@ -87,6 +90,7 @@ impl MePool { nat_ip_detected: Arc::new(RwLock::new(None)), nat_probe, nat_stun, + detected_ipv6, pool_size: 2, proxy_map_v4: Arc::new(RwLock::new(proxy_map_v4)), proxy_map_v6: Arc::new(RwLock::new(proxy_map_v6)), @@ -294,6 +298,7 @@ impl MePool { let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); + let draining = Arc::new(AtomicBool::new(false)); let rpc_w = Arc::new(Mutex::new(RpcWriter { writer: hs.wr, key: hs.write_key, @@ -306,6 +311,7 @@ impl MePool { writer: rpc_w.clone(), cancel: cancel.clone(), degraded: degraded.clone(), + draining: draining.clone(), }; self.writers.write().await.push(writer.clone()); @@ -336,7 +342,7 @@ impl MePool { ) .await; if let Some(pool) = pool.upgrade() { - pool.remove_writer_and_reroute(writer_id).await; + pool.remove_writer_and_close_clients(writer_id).await; } if let Err(e) = res { warn!(error = %e, "ME reader ended"); @@ -368,11 +374,11 @@ impl MePool { tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); - if let Err(e) = rpc_w_ping.lock().await.send(&p).await { + if let Err(e) = rpc_w_ping.lock().await.send_and_flush(&p).await { debug!(error = %e, "Active ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() { - pool.remove_writer_and_reroute(writer_id).await; + pool.remove_writer_and_close_clients(writer_id).await; } break; } @@ -405,12 +411,11 @@ impl MePool { warn!(dc = %dc, "All ME servers for DC failed at init"); } - pub(crate) async fn remove_writer_and_reroute(&self, writer_id: u64) { - let mut queue = self.remove_writer_only(writer_id).await; - while let Some(bound) = queue.pop() { - if !self.reroute_conn(&bound, &mut queue).await { - let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; - } + pub(crate) async fn remove_writer_and_close_clients(&self, writer_id: u64) { + let conns = self.remove_writer_only(writer_id).await; + for bound in conns { + let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; + let _ = self.registry.unregister(bound.conn_id).await; } } @@ -425,79 +430,28 @@ impl MePool { self.registry.writer_lost(writer_id).await } - async fn reroute_conn(&self, bound: &BoundConn, backlog: &mut Vec) -> bool { - let payload = super::wire::build_proxy_req_payload( - bound.conn_id, - bound.meta.client_addr, - bound.meta.our_addr, - &[], - self.proxy_tag.as_deref(), - bound.meta.proto_flags, - ); - - let mut attempts = 0; - loop { - let writers_snapshot = { - let ws = self.writers.read().await; - if ws.is_empty() { - return false; - } - ws.clone() - }; - let mut candidates = self.candidate_indices_for_dc(&writers_snapshot, bound.meta.target_dc).await; - if candidates.is_empty() { - return false; - } - candidates.sort_by_key(|idx| { - writers_snapshot[*idx] - .degraded - .load(Ordering::Relaxed) - .then_some(1usize) - .unwrap_or(0) - }); - let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidates.len(); - - for offset in 0..candidates.len() { - let idx = candidates[(start + offset) % candidates.len()]; - let w = &writers_snapshot[idx]; - if let Ok(mut guard) = w.writer.try_lock() { - let send_res = guard.send(&payload).await; - drop(guard); - match send_res { - Ok(()) => { - self.registry - .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone()) - .await; - return true; - } - Err(e) => { - warn!(error = %e, writer_id = w.id, "ME reroute send failed"); - backlog.extend(self.remove_writer_only(w.id).await); - } - } - continue; - } - } - - let w = writers_snapshot[candidates[start]].clone(); - match w.writer.lock().await.send(&payload).await { - Ok(()) => { - self.registry - .bind_writer(bound.conn_id, w.id, w.writer.clone(), bound.meta.clone()) - .await; - return true; - } - Err(e) => { - warn!(error = %e, writer_id = w.id, "ME reroute send failed (blocking)"); - backlog.extend(self.remove_writer_only(w.id).await); - } - } - - attempts += 1; - if attempts > 3 { - return false; + pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { + { + let mut ws = self.writers.write().await; + if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) { + w.draining.store(true, Ordering::Relaxed); } } + + let pool = Arc::downgrade(self); + tokio::spawn(async move { + loop { + if let Some(p) = pool.upgrade() { + if p.registry.is_writer_empty(writer_id).await { + let _ = p.remove_writer_only(writer_id).await; + break; + } + tokio::time::sleep(Duration::from_secs(1)).await; + } else { + break; + } + } + }); } } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index b53ddef..fb40fdb 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -136,7 +136,7 @@ pub(crate) async fn reader_loop( let mut pong = Vec::with_capacity(12); pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); pong.extend_from_slice(&ping_id.to_le_bytes()); - if let Err(e) = writer.lock().await.send(&pong).await { + if let Err(e) = writer.lock().await.send_and_flush(&pong).await { warn!(error = %e, "PONG send failed"); break; } @@ -176,7 +176,7 @@ async fn send_close_conn(writer: &Arc>, conn_id: u64) { p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = writer.lock().await.send(&p).await { + if let Err(e) = writer.lock().await.send_and_flush(&p).await { debug!(conn_id, error = %e, "Failed to send RPC_CLOSE_CONN"); } } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 9905d1d..04f8baa 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -28,12 +28,28 @@ pub struct ConnWriter { pub writer: Arc>, } +struct RegistryInner { + map: HashMap>, + writers: HashMap>>, + writer_for_conn: HashMap, + conns_for_writer: HashMap>, + meta: HashMap, +} + +impl RegistryInner { + fn new() -> Self { + Self { + map: HashMap::new(), + writers: HashMap::new(), + writer_for_conn: HashMap::new(), + conns_for_writer: HashMap::new(), + meta: HashMap::new(), + } + } +} + pub struct ConnRegistry { - map: RwLock>>, - writers: RwLock>>>, - writer_for_conn: RwLock>, - conns_for_writer: RwLock>>, - meta: RwLock>, + inner: RwLock, next_id: AtomicU64, } @@ -41,11 +57,7 @@ impl ConnRegistry { pub fn new() -> Self { let start = rand::random::() | 1; Self { - map: RwLock::new(HashMap::new()), - writers: RwLock::new(HashMap::new()), - writer_for_conn: RwLock::new(HashMap::new()), - conns_for_writer: RwLock::new(HashMap::new()), - meta: RwLock::new(HashMap::new()), + inner: RwLock::new(RegistryInner::new()), next_id: AtomicU64::new(start), } } @@ -53,23 +65,27 @@ impl ConnRegistry { pub async fn register(&self) -> (u64, mpsc::Receiver) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); let (tx, rx) = mpsc::channel(1024); - self.map.write().await.insert(id, tx); + self.inner.write().await.map.insert(id, tx); (id, rx) } - pub async fn unregister(&self, id: u64) { - self.map.write().await.remove(&id); - self.meta.write().await.remove(&id); - if let Some(writer_id) = self.writer_for_conn.write().await.remove(&id) { - if let Some(list) = self.conns_for_writer.write().await.get_mut(&writer_id) { - list.retain(|c| *c != id); + /// Unregister connection, returning associated writer_id if any. + pub async fn unregister(&self, id: u64) -> Option { + let mut inner = self.inner.write().await; + inner.map.remove(&id); + inner.meta.remove(&id); + if let Some(writer_id) = inner.writer_for_conn.remove(&id) { + if let Some(set) = inner.conns_for_writer.get_mut(&writer_id) { + set.remove(&id); } + return Some(writer_id); } + None } pub async fn route(&self, id: u64, resp: MeResponse) -> bool { - let m = self.map.read().await; - if let Some(tx) = m.get(&id) { + let inner = self.inner.read().await; + if let Some(tx) = inner.map.get(&id) { tx.try_send(resp).is_ok() } else { false @@ -83,40 +99,38 @@ impl ConnRegistry { writer: Arc>, meta: ConnMeta, ) { - self.meta.write().await.entry(conn_id).or_insert(meta); - self.writer_for_conn.write().await.insert(conn_id, writer_id); - self.writers.write().await.entry(writer_id).or_insert_with(|| writer.clone()); - self.conns_for_writer - .write() - .await + let mut inner = self.inner.write().await; + inner.meta.entry(conn_id).or_insert(meta); + inner.writer_for_conn.insert(conn_id, writer_id); + inner.writers.entry(writer_id).or_insert_with(|| writer.clone()); + inner + .conns_for_writer .entry(writer_id) - .or_insert_with(Vec::new) - .push(conn_id); + .or_insert_with(HashSet::new) + .insert(conn_id); } pub async fn get_writer(&self, conn_id: u64) -> Option { - let writer_id = { - let guard = self.writer_for_conn.read().await; - guard.get(&conn_id).cloned() - }?; - let writer = { - let guard = self.writers.read().await; - guard.get(&writer_id).cloned() - }?; + let inner = self.inner.read().await; + let writer_id = inner.writer_for_conn.get(&conn_id).cloned()?; + let writer = inner.writers.get(&writer_id).cloned()?; Some(ConnWriter { writer_id, writer }) } pub async fn writer_lost(&self, writer_id: u64) -> Vec { - self.writers.write().await.remove(&writer_id); - let conns = self.conns_for_writer.write().await.remove(&writer_id).unwrap_or_default(); + let mut inner = self.inner.write().await; + inner.writers.remove(&writer_id); + let conns = inner + .conns_for_writer + .remove(&writer_id) + .unwrap_or_default() + .into_iter() + .collect::>(); let mut out = Vec::new(); - let mut writer_for_conn = self.writer_for_conn.write().await; - let meta = self.meta.read().await; - for conn_id in conns { - writer_for_conn.remove(&conn_id); - if let Some(m) = meta.get(&conn_id) { + inner.writer_for_conn.remove(&conn_id); + if let Some(m) = inner.meta.get(&conn_id) { out.push(BoundConn { conn_id, meta: m.clone(), @@ -127,7 +141,16 @@ impl ConnRegistry { } pub async fn get_meta(&self, conn_id: u64) -> Option { - let guard = self.meta.read().await; - guard.get(&conn_id).cloned() + let inner = self.inner.read().await; + inner.meta.get(&conn_id).cloned() + } + + pub async fn is_writer_empty(&self, writer_id: u64) -> bool { + let inner = self.inner.read().await; + inner + .conns_for_writer + .get(&writer_id) + .map(|s| s.is_empty()) + .unwrap_or(true) } } diff --git a/src/transport/middle_proxy/rotation.rs b/src/transport/middle_proxy/rotation.rs index 5457f70..6d94f3e 100644 --- a/src/transport/middle_proxy/rotation.rs +++ b/src/transport/middle_proxy/rotation.rs @@ -31,8 +31,8 @@ pub async fn me_rotation_task(pool: Arc, rng: Arc, interva info!(addr = %w.addr, writer_id = w.id, "Rotating ME connection"); match pool.connect_one(w.addr, rng.as_ref()).await { Ok(()) => { - // Remove old writer after new one is up. - pool.remove_writer_and_reroute(w.id).await; + // Mark old writer for graceful drain; removal happens when sessions finish. + pool.mark_writer_draining(w.id).await; } Err(e) => { warn!(addr = %w.addr, writer_id = w.id, error = %e, "ME rotation connect failed"); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 5eaacf0..2b0c42e 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -55,7 +55,7 @@ impl MePool { Ok(()) => return Ok(()), Err(e) => { warn!(error = %e, writer_id = current.writer_id, "ME write failed"); - self.remove_writer_and_reroute(current.writer_id).await; + self.remove_writer_and_close_clients(current.writer_id).await; continue; } } @@ -76,22 +76,29 @@ impl MePool { return Err(ProxyError::Proxy("No ME writers available for target DC".into())); } emergency_attempts += 1; - let map = self.proxy_map_v4.read().await; - if let Some(addrs) = map.get(&(target_dc as i32)) { - let mut shuffled = addrs.clone(); - shuffled.shuffle(&mut rand::rng()); - drop(map); - for (ip, port) in shuffled { - let addr = SocketAddr::new(ip, port); - if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { - break; + for family in self.family_order() { + let map_guard = match family { + IpFamily::V4 => self.proxy_map_v4.read().await, + IpFamily::V6 => self.proxy_map_v6.read().await, + }; + if let Some(addrs) = map_guard.get(&(target_dc as i32)) { + let mut shuffled = addrs.clone(); + shuffled.shuffle(&mut rand::rng()); + drop(map_guard); + for (ip, port) in shuffled { + let addr = SocketAddr::new(ip, port); + if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { + break; + } } + tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await; + let ws2 = self.writers.read().await; + writers_snapshot = ws2.clone(); + drop(ws2); + candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; + break; } - tokio::time::sleep(Duration::from_millis(100 * emergency_attempts)).await; - let ws2 = self.writers.read().await; - writers_snapshot = ws2.clone(); - drop(ws2); - candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; + drop(map_guard); } if candidate_indices.is_empty() { return Err(ProxyError::Proxy("No ME writers available for target DC".into())); @@ -99,11 +106,10 @@ impl MePool { } candidate_indices.sort_by_key(|idx| { - writers_snapshot[*idx] - .degraded - .load(Ordering::Relaxed) - .then_some(1usize) - .unwrap_or(0) + let w = &writers_snapshot[*idx]; + let degraded = w.degraded.load(Ordering::Relaxed); + let draining = w.draining.load(Ordering::Relaxed); + (draining as usize, degraded as usize) }); let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len(); @@ -111,6 +117,9 @@ impl MePool { for offset in 0..candidate_indices.len() { let idx = candidate_indices[(start + offset) % candidate_indices.len()]; let w = &writers_snapshot[idx]; + if w.draining.load(Ordering::Relaxed) { + continue; + } if let Ok(mut guard) = w.writer.try_lock() { let send_res = guard.send(&payload).await; drop(guard); @@ -123,7 +132,7 @@ impl MePool { } Err(e) => { warn!(error = %e, writer_id = w.id, "ME write failed"); - self.remove_writer_and_reroute(w.id).await; + self.remove_writer_and_close_clients(w.id).await; continue; } } @@ -131,6 +140,9 @@ impl MePool { } let w = writers_snapshot[candidate_indices[start]].clone(); + if w.draining.load(Ordering::Relaxed) { + continue; + } match w.writer.lock().await.send(&payload).await { Ok(()) => { self.registry @@ -140,7 +152,7 @@ impl MePool { } Err(e) => { warn!(error = %e, writer_id = w.id, "ME write failed (blocking)"); - self.remove_writer_and_reroute(w.id).await; + self.remove_writer_and_close_clients(w.id).await; } } } @@ -151,9 +163,9 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if let Err(e) = w.writer.lock().await.send(&p).await { + if let Err(e) = w.writer.lock().await.send_and_flush(&p).await { debug!(error = %e, "ME close write failed"); - self.remove_writer_and_reroute(w.writer_id).await; + self.remove_writer_and_close_clients(w.writer_id).await; } } else { debug!(conn_id, "ME close skipped (writer missing)"); @@ -213,17 +225,24 @@ impl MePool { } if preferred.is_empty() { - return (0..writers.len()).collect(); + return (0..writers.len()) + .filter(|i| !writers[*i].draining.load(Ordering::Relaxed)) + .collect(); } let mut out = Vec::new(); for (idx, w) in writers.iter().enumerate() { + if w.draining.load(Ordering::Relaxed) { + continue; + } if preferred.iter().any(|p| *p == w.addr) { out.push(idx); } } if out.is_empty() { - return (0..writers.len()).collect(); + return (0..writers.len()) + .filter(|i| !writers[*i].draining.load(Ordering::Relaxed)) + .collect(); } out }