Merge pull request #133 from telemt/flow

New [network] section + ME Fixes + small bugs coverage
This commit is contained in:
Alexey
2026-02-18 06:09:36 +03:00
committed by GitHub
19 changed files with 980 additions and 442 deletions

View File

@@ -210,11 +210,18 @@ then Ctrl+X -> Y -> Enter to save
```toml
# === General Settings ===
[general]
# prefer_ipv6 is deprecated; use [network].prefer
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[network]
ipv4 = true
ipv6 = true # set false to disable, omit for auto
prefer = 4 # 4 or 6
multipath = false
[general.modes]
classic = false
secure = false

View File

@@ -1,10 +1,19 @@
# === General Settings ===
[general]
prefer_ipv6 = true
# prefer_ipv6 is deprecated; use [network].prefer instead
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = true
#ad_tag = "00000000000000000000000000000000"
[network]
# Enable/disable families; ipv6 = true/false/auto(None)
ipv4 = true
ipv6 = true
# prefer = 4 or 6
prefer = 4
multipath = false
# Log level: debug | verbose | normal | silent
# Can be overridden with --silent or --log-level CLI flags
# RUST_LOG env var takes absolute priority over all of these

View File

@@ -189,11 +189,18 @@ r#"# Telemt MTProxy — auto-generated config
show_link = ["{username}"]
[general]
# prefer_ipv6 is deprecated; use [network].prefer
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"
[network]
ipv4 = true
ipv6 = true
prefer = 4
multipath = false
[general.modes]
classic = false
secure = false

View File

@@ -54,6 +54,10 @@ fn default_metrics_whitelist() -> Vec<IpAddr> {
vec!["127.0.0.1".parse().unwrap(), "::1".parse().unwrap()]
}
fn default_prefer_4() -> u8 {
4
}
fn default_unknown_dc_log_path() -> Option<String> {
Some("unknown-dc.txt".to_string())
}
@@ -185,6 +189,32 @@ impl std::fmt::Display for LogLevel {
}
}
fn validate_network_cfg(net: &mut NetworkConfig) -> Result<()> {
if !net.ipv4 && matches!(net.ipv6, Some(false)) {
return Err(ProxyError::Config(
"Both ipv4 and ipv6 are disabled in [network]".to_string(),
));
}
if net.prefer != 4 && net.prefer != 6 {
return Err(ProxyError::Config(
"network.prefer must be 4 or 6".to_string(),
));
}
if !net.ipv4 && net.prefer == 4 {
warn!("prefer=4 but ipv4=false; forcing prefer=6");
net.prefer = 6;
}
if matches!(net.ipv6, Some(false)) && net.prefer == 6 {
warn!("prefer=6 but ipv6=false; forcing prefer=4");
net.prefer = 4;
}
Ok(())
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -207,6 +237,34 @@ impl Default for ProxyModes {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkConfig {
#[serde(default = "default_true")]
pub ipv4: bool,
/// None = auto-detect IPv6 availability
#[serde(default)]
pub ipv6: Option<bool>,
/// 4 or 6
#[serde(default = "default_prefer_4")]
pub prefer: u8,
#[serde(default)]
pub multipath: bool,
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
ipv4: true,
ipv6: None,
prefer: 4,
multipath: false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
#[serde(default)]
@@ -609,6 +667,9 @@ pub struct ProxyConfig {
#[serde(default)]
pub general: GeneralConfig,
#[serde(default)]
pub network: NetworkConfig,
#[serde(default)]
pub server: ServerConfig,
@@ -697,6 +758,16 @@ impl ProxyConfig {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
}
// Migration: prefer_ipv6 -> network.prefer
if config.general.prefer_ipv6 {
if config.network.prefer == 4 {
config.network.prefer = 6;
}
warn!("prefer_ipv6 is deprecated, use [network].prefer = 6");
}
validate_network_cfg(&mut config.network)?;
// Random fake_cert_len
use rand::Rng;
config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);

View File

@@ -16,6 +16,7 @@ mod config;
mod crypto;
mod error;
mod ip_tracker;
mod network;
mod metrics;
mod protocol;
mod proxy;
@@ -27,16 +28,14 @@ mod util;
use crate::config::{LogLevel, ProxyConfig};
use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::ClientHandler;
use crate::stats::{ReplayChecker, Stats};
use crate::stream::BufferPool;
use crate::transport::middle_proxy::{
MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line,
stun_probe,
};
use crate::transport::{ListenOptions, UpstreamManager, create_listener};
use crate::util::ip::detect_ip;
use crate::protocol::constants::{TG_MIDDLE_PROXIES_V4, TG_MIDDLE_PROXIES_V6};
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
@@ -219,8 +218,17 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
warn!("Using default tls_domain. Consider setting a custom domain.");
}
let prefer_ipv6 = config.general.prefer_ipv6;
let mut use_middle_proxy = config.general.use_middle_proxy;
let probe = run_probe(
&config.network,
config.general.middle_proxy_nat_stun.clone(),
config.general.middle_proxy_nat_probe,
)
.await?;
let decision = decide_network_capabilities(&config.network, &probe);
log_probe_result(&probe, &decision);
let prefer_ipv6 = decision.prefer_ipv6();
let mut use_middle_proxy = config.general.use_middle_proxy && (decision.ipv4_me || decision.ipv6_me);
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
@@ -244,39 +252,9 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
// Connection concurrency limit
let _max_connections = Arc::new(Semaphore::new(10_000));
// STUN check before choosing transport
if use_middle_proxy {
match stun_probe(config.general.middle_proxy_nat_stun.clone()).await {
Ok(Some(probe)) => {
info!(
local_ip = %probe.local_addr.ip(),
reflected_ip = %probe.reflected_addr.ip(),
"STUN Autodetect:"
);
if probe.local_addr.ip() != probe.reflected_addr.ip()
&& !config.general.stun_iface_mismatch_ignore
{
match crate::transport::middle_proxy::detect_public_ip().await {
Some(ip) => {
info!(
local_ip = %probe.local_addr.ip(),
reflected_ip = %probe.reflected_addr.ip(),
public_ip = %ip,
"STUN mismatch but public IP auto-detected, continuing with middle proxy"
);
}
None => {
warn!(
"STUN/IP-on-Interface mismatch and public IP auto-detect failed -> fallback to direct-DC"
);
use_middle_proxy = false;
}
}
}
}
Ok(None) => warn!("STUN probe returned no address; continuing"),
Err(e) => warn!(error = %e, "STUN probe failed; continuing"),
}
if use_middle_proxy && !decision.ipv4_me && !decision.ipv6_me {
warn!("No usable IP family for Middle Proxy detected; falling back to direct DC");
use_middle_proxy = false;
}
// =====================================================================
@@ -351,6 +329,8 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
cfg_v4.map.clone(),
cfg_v6.map.clone(),
cfg_v4.default_dc.or(cfg_v6.default_dc),
decision.clone(),
rng.clone(),
);
match pool.init(2, &rng).await {
@@ -482,7 +462,12 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
info!("================= Telegram DC Connectivity =================");
let ping_results = upstream_manager
.ping_all_dcs(prefer_ipv6, &config.dc_overrides)
.ping_all_dcs(
prefer_ipv6,
&config.dc_overrides,
decision.ipv4_dc,
decision.ipv6_dc,
)
.await;
for upstream_result in &ping_results {
@@ -559,8 +544,15 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
// Background tasks
let um_clone = upstream_manager.clone();
let decision_clone = decision.clone();
tokio::spawn(async move {
um_clone.run_health_checks(prefer_ipv6).await;
um_clone
.run_health_checks(
prefer_ipv6,
decision_clone.ipv4_dc,
decision_clone.ipv6_dc,
)
.await;
});
let rc_clone = replay_checker.clone();
@@ -568,16 +560,31 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
rc_clone.run_periodic_cleanup().await;
});
let detected_ip = detect_ip().await;
let detected_ip_v4: Option<std::net::IpAddr> = probe
.reflected_ipv4
.map(|s| s.ip())
.or_else(|| probe.detected_ipv4.map(std::net::IpAddr::V4));
let detected_ip_v6: Option<std::net::IpAddr> = probe
.reflected_ipv6
.map(|s| s.ip())
.or_else(|| probe.detected_ipv6.map(std::net::IpAddr::V6));
debug!(
"Detected IPs: v4={:?} v6={:?}",
detected_ip.ipv4, detected_ip.ipv6
detected_ip_v4, detected_ip_v6
);
let mut listeners = Vec::new();
for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.server.port);
if addr.is_ipv4() && !decision.ipv4_dc {
warn!(%addr, "Skipping IPv4 listener: IPv4 disabled by [network]");
continue;
}
if addr.is_ipv6() && !decision.ipv6_dc {
warn!(%addr, "Skipping IPv6 listener: IPv6 disabled by [network]");
continue;
}
let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default()
@@ -594,11 +601,11 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
} else if listener_conf.ip.is_unspecified() {
// Auto-detect for unspecified addresses
if listener_conf.ip.is_ipv4() {
detected_ip.ipv4
detected_ip_v4
.map(|ip| ip.to_string())
.unwrap_or_else(|| listener_conf.ip.to_string())
} else {
detected_ip.ipv6
detected_ip_v6
.map(|ip| ip.to_string())
.unwrap_or_else(|| listener_conf.ip.to_string())
}
@@ -626,9 +633,8 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai
let (host, port) = if let Some(ref h) = config.general.links.public_host {
(h.clone(), config.general.links.public_port.unwrap_or(config.server.port))
} else {
let ip = detected_ip
.ipv4
.or(detected_ip.ipv6)
let ip = detected_ip_v4
.or(detected_ip_v6)
.map(|ip| ip.to_string());
if ip.is_none() {
warn!("show_link is configured but public IP could not be detected. Set public_host in config.");

4
src/network/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod probe;
pub mod stun;
pub use stun::IpFamily;

225
src/network/probe.rs Normal file
View File

@@ -0,0 +1,225 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use tracing::{info, warn};
use crate::config::NetworkConfig;
use crate::error::Result;
use crate::network::stun::{stun_probe_dual, DualStunResult, IpFamily};
#[derive(Debug, Clone, Default)]
pub struct NetworkProbe {
pub detected_ipv4: Option<Ipv4Addr>,
pub detected_ipv6: Option<Ipv6Addr>,
pub reflected_ipv4: Option<SocketAddr>,
pub reflected_ipv6: Option<SocketAddr>,
pub ipv4_is_bogon: bool,
pub ipv6_is_bogon: bool,
pub ipv4_nat_detected: bool,
pub ipv6_nat_detected: bool,
pub ipv4_usable: bool,
pub ipv6_usable: bool,
}
#[derive(Debug, Clone, Default)]
pub struct NetworkDecision {
pub ipv4_dc: bool,
pub ipv6_dc: bool,
pub ipv4_me: bool,
pub ipv6_me: bool,
pub effective_prefer: u8,
pub effective_multipath: bool,
}
impl NetworkDecision {
pub fn prefer_ipv6(&self) -> bool {
self.effective_prefer == 6
}
pub fn me_families(&self) -> Vec<IpFamily> {
let mut res = Vec::new();
if self.ipv4_me {
res.push(IpFamily::V4);
}
if self.ipv6_me {
res.push(IpFamily::V6);
}
res
}
}
pub async fn run_probe(config: &NetworkConfig, stun_addr: Option<String>, nat_probe: bool) -> Result<NetworkProbe> {
let mut probe = NetworkProbe::default();
probe.detected_ipv4 = detect_local_ip_v4();
probe.detected_ipv6 = detect_local_ip_v6();
probe.ipv4_is_bogon = probe.detected_ipv4.map(is_bogon_v4).unwrap_or(false);
probe.ipv6_is_bogon = probe.detected_ipv6.map(is_bogon_v6).unwrap_or(false);
let stun_server = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
let stun_res = if nat_probe {
stun_probe_dual(&stun_server).await?
} else {
DualStunResult::default()
};
probe.reflected_ipv4 = stun_res.v4.map(|r| r.reflected_addr);
probe.reflected_ipv6 = stun_res.v6.map(|r| r.reflected_addr);
probe.ipv4_nat_detected = match (probe.detected_ipv4, probe.reflected_ipv4) {
(Some(det), Some(reflected)) => det != reflected.ip(),
_ => false,
};
probe.ipv6_nat_detected = match (probe.detected_ipv6, probe.reflected_ipv6) {
(Some(det), Some(reflected)) => det != reflected.ip(),
_ => false,
};
probe.ipv4_usable = config.ipv4
&& probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.map(|r| !is_bogon(r.ip())).unwrap_or(false));
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
probe.ipv6_usable = ipv6_enabled
&& probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.map(|r| !is_bogon(r.ip())).unwrap_or(false));
Ok(probe)
}
pub fn decide_network_capabilities(config: &NetworkConfig, probe: &NetworkProbe) -> NetworkDecision {
let mut decision = NetworkDecision::default();
decision.ipv4_dc = config.ipv4 && probe.detected_ipv4.is_some();
decision.ipv6_dc = config.ipv6.unwrap_or(probe.detected_ipv6.is_some()) && probe.detected_ipv6.is_some();
decision.ipv4_me = config.ipv4
&& probe.detected_ipv4.is_some()
&& (!probe.ipv4_is_bogon || probe.reflected_ipv4.is_some());
let ipv6_enabled = config.ipv6.unwrap_or(probe.detected_ipv6.is_some());
decision.ipv6_me = ipv6_enabled
&& probe.detected_ipv6.is_some()
&& (!probe.ipv6_is_bogon || probe.reflected_ipv6.is_some());
decision.effective_prefer = match config.prefer {
6 if decision.ipv6_me || decision.ipv6_dc => 6,
4 if decision.ipv4_me || decision.ipv4_dc => 4,
6 => {
warn!("prefer=6 requested but IPv6 unavailable; falling back to IPv4");
4
}
_ => 4,
};
let me_families = decision.ipv4_me as u8 + decision.ipv6_me as u8;
decision.effective_multipath = config.multipath && me_families >= 2;
decision
}
fn detect_local_ip_v4() -> Option<Ipv4Addr> {
let socket = UdpSocket::bind("0.0.0.0:0").ok()?;
socket.connect("8.8.8.8:80").ok()?;
match socket.local_addr().ok()?.ip() {
IpAddr::V4(v4) => Some(v4),
_ => None,
}
}
fn detect_local_ip_v6() -> Option<Ipv6Addr> {
let socket = UdpSocket::bind("[::]:0").ok()?;
socket.connect("[2001:4860:4860::8888]:80").ok()?;
match socket.local_addr().ok()?.ip() {
IpAddr::V6(v6) => Some(v6),
_ => None,
}
}
pub fn is_bogon(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => is_bogon_v4(v4),
IpAddr::V6(v6) => is_bogon_v6(v6),
}
}
pub fn is_bogon_v4(ip: Ipv4Addr) -> bool {
let octets = ip.octets();
if ip.is_private() || ip.is_loopback() || ip.is_link_local() {
return true;
}
if octets[0] == 0 {
return true;
}
if octets[0] == 100 && (octets[1] & 0xC0) == 64 {
return true;
}
if octets[0] == 192 && octets[1] == 0 && octets[2] == 0 {
return true;
}
if octets[0] == 192 && octets[1] == 0 && octets[2] == 2 {
return true;
}
if octets[0] == 198 && (octets[1] & 0xFE) == 18 {
return true;
}
if octets[0] == 198 && octets[1] == 51 && octets[2] == 100 {
return true;
}
if octets[0] == 203 && octets[1] == 0 && octets[2] == 113 {
return true;
}
if ip.is_multicast() {
return true;
}
if octets[0] >= 240 {
return true;
}
if ip.is_broadcast() {
return true;
}
false
}
pub fn is_bogon_v6(ip: Ipv6Addr) -> bool {
if ip.is_unspecified() || ip.is_loopback() || ip.is_unique_local() {
return true;
}
let segs = ip.segments();
if (segs[0] & 0xFFC0) == 0xFE80 {
return true;
}
if segs[0..5] == [0, 0, 0, 0, 0] && segs[5] == 0xFFFF {
return true;
}
if segs[0] == 0x0100 && segs[1..4] == [0, 0, 0] {
return true;
}
if segs[0] == 0x2001 && segs[1] == 0x0db8 {
return true;
}
if segs[0] == 0x2002 {
return true;
}
if ip.is_multicast() {
return true;
}
false
}
pub fn log_probe_result(probe: &NetworkProbe, decision: &NetworkDecision) {
info!(
ipv4 = probe.detected_ipv4.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
ipv6 = probe.detected_ipv6.as_ref().map(|v| v.to_string()).unwrap_or_else(|| "-".into()),
reflected_v4 = probe.reflected_ipv4.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
reflected_v6 = probe.reflected_ipv6.as_ref().map(|v| v.ip().to_string()).unwrap_or_else(|| "-".into()),
ipv4_bogon = probe.ipv4_is_bogon,
ipv6_bogon = probe.ipv6_is_bogon,
ipv4_me = decision.ipv4_me,
ipv6_me = decision.ipv6_me,
ipv4_dc = decision.ipv4_dc,
ipv6_dc = decision.ipv6_dc,
prefer = decision.effective_prefer,
multipath = decision.effective_multipath,
"Network capabilities resolved"
);
}

186
src/network/stun.rs Normal file
View File

@@ -0,0 +1,186 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::net::{lookup_host, UdpSocket};
use crate::error::{ProxyError, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IpFamily {
V4,
V6,
}
#[derive(Debug, Clone, Copy)]
pub struct StunProbeResult {
pub local_addr: SocketAddr,
pub reflected_addr: SocketAddr,
pub family: IpFamily,
}
#[derive(Debug, Default, Clone)]
pub struct DualStunResult {
pub v4: Option<StunProbeResult>,
pub v6: Option<StunProbeResult>,
}
pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
let (v4, v6) = tokio::join!(
stun_probe_family(stun_addr, IpFamily::V4),
stun_probe_family(stun_addr, IpFamily::V6),
);
Ok(DualStunResult {
v4: v4?,
v6: v6?,
})
}
pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> {
use rand::RngCore;
let bind_addr = match family {
IpFamily::V4 => "0.0.0.0:0",
IpFamily::V6 => "[::]:0",
};
let socket = UdpSocket::bind(bind_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
let target_addr = resolve_stun_addr(stun_addr, family).await?;
if let Some(addr) = target_addr {
socket
.connect(addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
} else {
return Ok(None);
}
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::rng().fill_bytes(&mut req[8..20]); // transaction ID
socket
.send(&req)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
let mut buf = [0u8; 256];
let n = socket
.recv(&mut buf)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?;
if n < 20 {
return Ok(None);
}
let magic = 0x2112A442u32.to_be_bytes();
let txid = &req[8..20];
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 {
break;
}
let family_byte = buf[idx + 1];
let port_bytes = [buf[idx + 2], buf[idx + 3]];
let len_check = match family_byte {
0x01 => 4,
0x02 => 16,
_ => 0,
};
if len_check == 0 || alen < 4 + len_check {
break;
}
let raw_ip = &buf[idx + 4..idx + 4 + len_check];
let mut port = u16::from_be_bytes(port_bytes);
let reflected_ip = if atype == 0x0020 {
port ^= ((magic[0] as u16) << 8) | magic[1] as u16;
match family_byte {
0x01 => {
let ip = [
raw_ip[0] ^ magic[0],
raw_ip[1] ^ magic[1],
raw_ip[2] ^ magic[2],
raw_ip[3] ^ magic[3],
];
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3]))
}
0x02 => {
let mut ip = [0u8; 16];
let xor_key = [magic.as_slice(), txid].concat();
for (i, b) in raw_ip.iter().enumerate().take(16) {
ip[i] = *b ^ xor_key[i];
}
IpAddr::V6(Ipv6Addr::from(ip))
}
_ => {
idx += (alen + 3) & !3;
continue;
}
}
} else {
match family_byte {
0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])),
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())),
_ => {
idx += (alen + 3) & !3;
continue;
}
}
};
let reflected_addr = SocketAddr::new(reflected_ip, port);
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
}
_ => {}
}
idx += (alen + 3) & !3;
}
Ok(None)
}
async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<SocketAddr>> {
if let Ok(addr) = stun_addr.parse::<SocketAddr>() {
return Ok(match (addr.is_ipv4(), family) {
(true, IpFamily::V4) | (false, IpFamily::V6) => Some(addr),
_ => None,
});
}
let addrs = lookup_host(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN resolve failed: {e}")))?;
let target = addrs
.filter(|a| match (a.is_ipv4(), family) {
(true, IpFamily::V4) => true,
(false, IpFamily::V6) => true,
_ => false,
})
.next();
Ok(target)
}

View File

@@ -80,7 +80,8 @@ where
}
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let datacenters = if config.general.prefer_ipv6 {
let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true);
let datacenters = if prefer_v6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
@@ -90,7 +91,6 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let dc_key = dc_idx.to_string();
if let Some(addrs) = config.dc_overrides.get(&dc_key) {
let prefer_v6 = config.general.prefer_ipv6;
let mut parsed = Vec::new();
for addr_str in addrs {
match addr_str.parse::<SocketAddr>() {

View File

@@ -16,6 +16,7 @@ use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32,
@@ -101,8 +102,13 @@ impl MePool {
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
let family = if local_addr.ip().is_ipv4() {
IpFamily::V4
} else {
IpFamily::V6
};
let reflected = if self.nat_probe {
self.maybe_reflect_public_addr().await
self.maybe_reflect_public_addr(family).await
} else {
None
};

View File

@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -7,107 +7,84 @@ use tracing::{debug, info, warn};
use rand::seq::SliceRandom;
use crate::crypto::SecureRandom;
use crate::network::IpFamily;
use super::MePool;
pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_connections: usize) {
let mut backoff: HashMap<i32, u64> = HashMap::new();
let mut last_attempt: HashMap<i32, Instant> = HashMap::new();
let mut backoff: HashMap<(i32, IpFamily), u64> = HashMap::new();
let mut last_attempt: HashMap<(i32, IpFamily), Instant> = HashMap::new();
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
// Per-DC coverage check
let map = pool.proxy_map_v4.read().await.clone();
let writer_addrs: std::collections::HashSet<SocketAddr> = pool
.writers
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
check_family(IpFamily::V4, &pool, &rng, &mut backoff, &mut last_attempt).await;
check_family(IpFamily::V6, &pool, &rng, &mut backoff, &mut last_attempt).await;
}
}
for (dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
if !has_coverage {
let delay = *backoff.get(dc).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(dc) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
warn!(dc = %dc, delay, "DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, &rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME reconnected for DC coverage");
backoff.insert(*dc, 30);
last_attempt.insert(*dc, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed"),
}
}
if !reconnected {
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(*dc, next);
last_attempt.insert(*dc, now);
}
async fn check_family(
family: IpFamily,
pool: &Arc<MePool>,
rng: &Arc<SecureRandom>,
backoff: &mut HashMap<(i32, IpFamily), u64>,
last_attempt: &mut HashMap<(i32, IpFamily), Instant>,
) {
let enabled = match family {
IpFamily::V4 => pool.decision.ipv4_me,
IpFamily::V6 => pool.decision.ipv6_me,
};
if !enabled {
return;
}
let map = match family {
IpFamily::V4 => pool.proxy_map_v4.read().await.clone(),
IpFamily::V6 => pool.proxy_map_v6.read().await.clone(),
};
let writer_addrs: HashSet<SocketAddr> = pool
.writers
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
for (dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs.contains(a));
if has_coverage {
continue;
}
let key = (*dc, family);
let delay = *backoff.get(&key).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(&key) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
// IPv6 coverage check (if available)
let map_v6 = pool.proxy_map_v6.read().await.clone();
let writer_addrs_v6: std::collections::HashSet<SocketAddr> = pool
.writers
.read()
.await
.iter()
.map(|w| w.addr)
.collect();
for (dc, addrs) in map_v6.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
let has_coverage = dc_addrs.iter().any(|a| writer_addrs_v6.contains(a));
if !has_coverage {
let delay = *backoff.get(dc).unwrap_or(&30);
let now = Instant::now();
if let Some(last) = last_attempt.get(dc) {
if now.duration_since(*last).as_secs() < delay {
continue;
}
}
warn!(dc = %dc, delay, "IPv6 DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, &rng).await {
Ok(()) => {
info!(%addr, dc = %dc, "ME reconnected for IPv6 DC coverage");
backoff.insert(*dc, 30);
last_attempt.insert(*dc, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, "ME reconnect failed (IPv6)"),
}
}
if !reconnected {
let next = (*backoff.get(dc).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(*dc, next);
last_attempt.insert(*dc, now);
warn!(dc = %dc, delay, ?family, "DC has no ME coverage, reconnecting...");
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
let mut reconnected = false;
for addr in shuffled {
match pool.connect_one(addr, rng.as_ref()).await {
Ok(()) => {
info!(%addr, dc = %dc, ?family, "ME reconnected for DC coverage");
backoff.insert(key, 30);
last_attempt.insert(key, now);
reconnected = true;
break;
}
Err(e) => debug!(%addr, dc = %dc, error = %e, ?family, "ME reconnect failed"),
}
}
if !reconnected {
let next = (*backoff.get(&key).unwrap_or(&30)).saturating_mul(2).min(300);
backoff.insert(key, next);
last_attempt.insert(key, now);
}
}
}

View File

@@ -19,7 +19,7 @@ use bytes::Bytes;
pub use health::me_health_monitor;
pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily};
pub use pool::MePool;
pub use pool_nat::{stun_probe, detect_public_ip, StunProbeResult};
pub use pool_nat::{stun_probe, detect_public_ip};
pub use registry::ConnRegistry;
pub use secret::fetch_proxy_secret;
pub use config_updater::{fetch_proxy_config, me_config_updater};

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
@@ -92,8 +93,16 @@ mod tests {
pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingReport> {
let mut reports = Vec::new();
let v4_map = pool.proxy_map_v4.read().await.clone();
let v6_map = pool.proxy_map_v6.read().await.clone();
let v4_map = if pool.decision.ipv4_me {
pool.proxy_map_v4.read().await.clone()
} else {
HashMap::new()
};
let v6_map = if pool.decision.ipv6_me {
pool.proxy_map_v6.read().await.clone()
} else {
HashMap::new()
};
let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new();
for (dc, addrs) in v4_map {

View File

@@ -12,6 +12,8 @@ use std::time::Duration;
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
use crate::network::probe::NetworkDecision;
use crate::network::IpFamily;
use crate::protocol::constants::*;
use super::ConnRegistry;
@@ -36,6 +38,8 @@ pub struct MePool {
pub(super) registry: Arc<ConnRegistry>,
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64,
pub(super) decision: NetworkDecision,
pub(super) rng: Arc<SecureRandom>,
pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
pub(super) nat_ip_cfg: Option<IpAddr>,
@@ -48,10 +52,16 @@ pub struct MePool {
pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<Option<(std::time::Instant, std::net::SocketAddr)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pool_size: usize,
}
#[derive(Debug, Default)]
pub struct NatReflectionCache {
pub v4: Option<(std::time::Instant, std::net::SocketAddr)>,
pub v6: Option<(std::time::Instant, std::net::SocketAddr)>,
}
impl MePool {
pub fn new(
proxy_tag: Option<Vec<u8>>,
@@ -62,11 +72,15 @@ impl MePool {
proxy_map_v4: HashMap<i32, Vec<(IpAddr, u16)>>,
proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>,
default_dc: Option<i32>,
decision: NetworkDecision,
rng: Arc<SecureRandom>,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0),
decision,
rng,
proxy_tag,
proxy_secret: Arc::new(RwLock::new(proxy_secret)),
nat_ip_cfg: nat_ip,
@@ -80,7 +94,7 @@ impl MePool {
next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(None)),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
})
}
@@ -103,29 +117,30 @@ impl MePool {
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
use std::collections::HashSet;
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers.iter().map(|w| w.addr).collect();
drop(writers);
for (_dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
if !dc_addrs.iter().any(|a| current.contains(a)) {
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
for family in self.family_order() {
let map = self.proxy_map_for_family(family).await;
for (_dc, addrs) in map.iter() {
let dc_addrs: Vec<SocketAddr> = addrs
.iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect();
if !dc_addrs.iter().any(|a| current.contains(a)) {
let mut shuffled = dc_addrs.clone();
shuffled.shuffle(&mut rand::rng());
for addr in shuffled {
if self.connect_one(addr, rng).await.is_ok() {
break;
}
}
}
}
if !self.decision.effective_multipath && !current.is_empty() {
break;
}
}
}
@@ -181,47 +196,82 @@ impl MePool {
}
}
pub(super) fn family_order(&self) -> Vec<IpFamily> {
let mut order = Vec::new();
if self.decision.prefer_ipv6() {
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
} else {
if self.decision.ipv4_me {
order.push(IpFamily::V4);
}
if self.decision.ipv6_me {
order.push(IpFamily::V6);
}
}
order
}
async fn proxy_map_for_family(&self, family: IpFamily) -> HashMap<i32, Vec<(IpAddr, u16)>> {
match family {
IpFamily::V4 => self.proxy_map_v4.read().await.clone(),
IpFamily::V6 => self.proxy_map_v6.read().await.clone(),
}
}
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let map = self.proxy_map_v4.read().await.clone();
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
let family_order = self.family_order();
let ks = self.key_selector().await;
info!(
me_servers = map.len(),
me_servers = self.proxy_map_v4.read().await.len(),
pool_size,
key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.len(),
"Initializing ME pool"
);
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
for family in family_order {
let map = self.proxy_map_for_family(family).await;
let dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map
.iter()
.map(|(dc, addrs)| (*dc, addrs.clone()))
.collect();
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
// Ensure at least one connection per DC; run DCs in parallel.
let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() {
continue;
}
let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng);
join.spawn(async move {
pool.connect_primary_for_dc(dc, addrs, rng_clone).await;
});
}
while let Some(_res) = join.join_next().await {}
// Additional connections up to pool_size total (round-robin across DCs)
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if self.connection_count() >= pool_size {
if !self.decision.effective_multipath && self.connection_count() > 0 {
break;
}
}
@@ -309,14 +359,15 @@ impl MePool {
}
_ = tokio::time::sleep(Duration::from_secs(wait)) => {}
}
let sent_id = ping_id;
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&ping_id.to_le_bytes());
ping_id = ping_id.wrapping_add(1);
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
tracker.insert(ping_id, (std::time::Instant::now(), writer_id));
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
if let Err(e) = rpc_w_ping.lock().await.send(&p).await {
debug!(error = %e, "Active ME ping failed, removing dead writer");
cancel_ping.cancel();

View File

@@ -4,19 +4,14 @@ use std::time::Duration;
use tracing::{info, warn};
use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon;
use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult};
use super::MePool;
use std::time::Instant;
#[derive(Debug, Clone, Copy)]
pub struct StunProbeResult {
pub local_addr: std::net::SocketAddr,
pub reflected_addr: std::net::SocketAddr,
}
pub async fn stun_probe(stun_addr: Option<String>) -> Result<Option<StunProbeResult>> {
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
let stun_addr = stun_addr.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
fetch_stun_binding(&stun_addr).await
stun_probe_dual(&stun_addr).await
}
pub async fn detect_public_ip() -> Option<IpAddr> {
@@ -35,7 +30,7 @@ impl MePool {
match (ip, nat_ip) {
(IpAddr::V4(src), IpAddr::V4(dst))
if is_privateish(IpAddr::V4(src))
if is_bogon(IpAddr::V4(src))
|| src.is_loopback()
|| src.is_unspecified() =>
{
@@ -55,7 +50,7 @@ impl MePool {
) -> std::net::SocketAddr {
let ip = if let Some(r) = reflected {
// Use reflected IP (not port) only when local address is non-public.
if is_privateish(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
if is_bogon(addr.ip()) || addr.ip().is_loopback() || addr.ip().is_unspecified() {
r.ip()
} else {
self.translate_ip_for_nat(addr.ip())
@@ -73,7 +68,7 @@ impl MePool {
return self.nat_ip_cfg;
}
if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None;
}
@@ -98,12 +93,19 @@ impl MePool {
}
}
pub(super) async fn maybe_reflect_public_addr(&self) -> Option<std::net::SocketAddr> {
pub(super) async fn maybe_reflect_public_addr(
&self,
family: IpFamily,
) -> Option<std::net::SocketAddr> {
const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
if let Some((ts, addr)) = *cache {
let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
if let Some((ts, addr)) = slot {
if ts.elapsed() < STUN_CACHE_TTL {
return Some(addr);
return Some(*addr);
}
}
}
@@ -112,12 +114,20 @@ impl MePool {
.nat_stun
.clone()
.unwrap_or_else(|| "stun.l.google.com:19302".to_string());
match fetch_stun_binding(&stun_addr).await {
Ok(sa) => {
if let Some(result) = sa {
info!(local = %result.local_addr, reflected = %result.reflected_addr, "NAT probe: reflected address");
match stun_probe_dual(&stun_addr).await {
Ok(res) => {
let picked: Option<StunProbeResult> = match family {
IpFamily::V4 => res.v4,
IpFamily::V6 => res.v6,
};
if let Some(result) = picked {
info!(local = %result.local_addr, reflected = %result.reflected_addr, family = ?family, "NAT probe: reflected address");
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() {
*cache = Some((Instant::now(), result.reflected_addr));
let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
*slot = Some((Instant::now(), result.reflected_addr));
}
Some(result.reflected_addr)
} else {
@@ -158,98 +168,3 @@ async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
let ip = text.trim().parse().ok();
Ok(ip)
}
async fn fetch_stun_binding(stun_addr: &str) -> Result<Option<StunProbeResult>> {
use rand::RngCore;
use tokio::net::UdpSocket;
let socket = UdpSocket::bind("0.0.0.0:0")
.await
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?;
socket
.connect(stun_addr)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN connect failed: {e}")))?;
// Build minimal Binding Request.
let mut req = vec![0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
rand::rng().fill_bytes(&mut req[8..20]);
socket
.send(&req)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN send failed: {e}")))?;
let mut buf = [0u8; 128];
let n = socket
.recv(&mut buf)
.await
.map_err(|e| ProxyError::Proxy(format!("STUN recv failed: {e}")))?;
if n < 20 {
return Ok(None);
}
// Parse attributes.
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
if alen < 8 {
break;
}
let family = buf[idx + 1];
if family != 0x01 {
// only IPv4 supported here
break;
}
let port_bytes = [buf[idx + 2], buf[idx + 3]];
let ip_bytes = [buf[idx + 4], buf[idx + 5], buf[idx + 6], buf[idx + 7]];
let (port, ip) = if atype == 0x0020 {
let magic = 0x2112A442u32.to_be_bytes();
let port = u16::from_be_bytes(port_bytes) ^ ((magic[0] as u16) << 8 | magic[1] as u16);
let ip = [
ip_bytes[0] ^ magic[0],
ip_bytes[1] ^ magic[1],
ip_bytes[2] ^ magic[2],
ip_bytes[3] ^ magic[3],
];
(port, ip)
} else {
(u16::from_be_bytes(port_bytes), ip_bytes)
};
let reflected = std::net::SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(ip[0], ip[1], ip[2], ip[3])),
port,
);
let local_addr = socket.local_addr().map_err(|e| {
ProxyError::Proxy(format!("STUN local_addr failed: {e}"))
})?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr: reflected,
}));
}
_ => {}
}
idx += (alen + 3) & !3; // 4-byte alignment
}
Ok(None)
}
fn is_privateish(ip: IpAddr) -> bool {
match ip {
IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(),
IpAddr::V6(v6) => v6.is_unique_local(),
}
}

View File

@@ -152,6 +152,9 @@ pub(crate) async fn reader_loop(
entry.1 = entry.1 * 0.8 + rtt * 0.2;
if rtt < entry.0 {
entry.0 = rtt;
} else {
// allow slow baseline drift upward to avoid stale minimum
entry.0 = entry.0 * 0.99 + rtt * 0.01;
}
let degraded_now = entry.1 > entry.0 * 2.0;
degraded.store(degraded_now, Ordering::Relaxed);

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tracing::{info, warn};
@@ -15,7 +16,12 @@ pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interva
let candidate = {
let ws = pool.writers.read().await;
ws.get(0).cloned()
if ws.is_empty() {
None
} else {
let idx = (pool.rr.load(std::sync::atomic::Ordering::Relaxed) as usize) % ws.len();
ws.get(idx).cloned()
}
};
let Some(w) = candidate else {
@@ -34,4 +40,3 @@ pub async fn me_rotation_task(pool: Arc<MePool>, rng: Arc<SecureRandom>, interva
}
}
}

View File

@@ -3,15 +3,14 @@ use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::RPC_CLOSE_EXT_U32;
use super::MePool;
use super::wire::build_proxy_req_payload;
use crate::crypto::SecureRandom;
use rand::seq::SliceRandom;
use super::registry::ConnMeta;
@@ -84,7 +83,7 @@ impl MePool {
drop(map);
for (ip, port) in shuffled {
let addr = SocketAddr::new(ip, port);
if self.connect_one(addr, &SecureRandom::new()).await.is_ok() {
if self.connect_one(addr, self.rng.as_ref()).await.is_ok() {
break;
}
}
@@ -173,32 +172,44 @@ impl MePool {
writers: &[super::pool::MeWriter],
target_dc: i16,
) -> Vec<usize> {
let mut preferred = Vec::<SocketAddr>::new();
let key = target_dc as i32;
let map = self.proxy_map_v4.read().await;
let mut preferred = Vec::<SocketAddr>::new();
if let Some(v) = map.get(&key) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map.get(&abs) {
for family in self.family_order() {
let map_guard = match family {
IpFamily::V4 => self.proxy_map_v4.read().await,
IpFamily::V6 => self.proxy_map_v6.read().await,
};
if let Some(v) = map_guard.get(&key) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let def = self.default_dc.load(Ordering::Relaxed);
if def != 0 {
if let Some(v) = map.get(&def) {
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map_guard.get(&abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let abs = key.abs();
if let Some(v) = map_guard.get(&-abs) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
if preferred.is_empty() {
let def = self.default_dc.load(Ordering::Relaxed);
if def != 0 {
if let Some(v) = map_guard.get(&def) {
preferred.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
}
}
}
drop(map_guard);
if !preferred.is_empty() && !self.decision.effective_multipath {
break;
}
}
if preferred.is_empty() {

View File

@@ -355,6 +355,8 @@ impl UpstreamManager {
&self,
prefer_ipv6: bool,
dc_overrides: &HashMap<String, Vec<String>>,
ipv4_enabled: bool,
ipv6_enabled: bool,
) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
@@ -374,85 +376,106 @@ impl UpstreamManager {
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut v6_results = Vec::new();
let mut v4_results = Vec::new();
let mut v6_results = Vec::with_capacity(NUM_DCS);
if ipv6_enabled {
for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
// === Ping IPv6 first ===
for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
let addr_v6 = SocketAddr::new(dc_v6, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v6)
).await;
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v6)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: Some(rtt_ms),
error: None,
}
}
DcPingResult {
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v6_results.push(ping_result);
}
} else {
for dc_zero_idx in 0..NUM_DCS {
let dc_v6 = TG_DATACENTERS_V6[dc_zero_idx];
v6_results.push(DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
dc_addr: SocketAddr::new(dc_v6, TG_DATACENTER_PORT),
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v6,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v6_results.push(ping_result);
error: Some("ipv6 disabled".to_string()),
});
}
}
// === Then ping IPv4 ===
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
let mut v4_results = Vec::with_capacity(NUM_DCS);
if ipv4_enabled {
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
let addr_v4 = SocketAddr::new(dc_v4, TG_DATACENTER_PORT);
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v4)
).await;
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr_v4)
).await;
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
let ping_result = match result {
Ok(Ok(rtt_ms)) => {
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: Some(rtt_ms),
error: None,
}
}
DcPingResult {
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v4_results.push(ping_result);
}
} else {
for dc_zero_idx in 0..NUM_DCS {
let dc_v4 = TG_DATACENTERS_V4[dc_zero_idx];
v4_results.push(DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
dc_addr: SocketAddr::new(dc_v4, TG_DATACENTER_PORT),
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr: addr_v4,
rtt_ms: None,
error: Some("timeout".to_string()),
},
};
v4_results.push(ping_result);
error: Some("ipv4 disabled".to_string()),
});
}
}
// === Ping DC overrides (v4/v6) ===
@@ -470,6 +493,9 @@ impl UpstreamManager {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
let is_v6 = addr.is_ipv6();
if (is_v6 && !ipv6_enabled) || (!is_v6 && !ipv4_enabled) {
continue;
}
let result = tokio::time::timeout(
Duration::from_secs(DC_PING_TIMEOUT_SECS),
self.ping_single_dc(&upstream_config, addr)
@@ -551,7 +577,7 @@ impl UpstreamManager {
/// Background health check: rotates through DCs, 30s interval.
/// Uses preferred IP version based on config.
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
pub async fn run_health_checks(&self, prefer_ipv6: bool, ipv4_enabled: bool, ipv6_enabled: bool) {
let mut dc_rotation = 0usize;
loop {
@@ -560,16 +586,24 @@ impl UpstreamManager {
let dc_zero_idx = dc_rotation % NUM_DCS;
dc_rotation += 1;
let dc_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
let primary_v6 = SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT);
let primary_v4 = SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT);
let dc_addr = if prefer_ipv6 && ipv6_enabled {
primary_v6
} else if ipv4_enabled {
primary_v4
} else if ipv6_enabled {
primary_v6
} else {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
continue;
};
let fallback_addr = if prefer_ipv6 {
SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT)
let fallback_addr = if dc_addr.is_ipv6() && ipv4_enabled {
Some(primary_v4)
} else if dc_addr.is_ipv4() && ipv6_enabled {
Some(primary_v6)
} else {
SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT)
None
};
let count = self.upstreams.read().await.len();
@@ -608,48 +642,60 @@ impl UpstreamManager {
// Try fallback
debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
let start2 = Instant::now();
let result2 = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, fallback_addr)
).await;
if let Some(fallback_addr) = fallback_addr {
let start2 = Instant::now();
let result2 = tokio::time::timeout(
Duration::from_secs(10),
self.connect_via_upstream(&config, fallback_addr)
).await;
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result2 {
Ok(Ok(_stream)) => {
let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
info!(
rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered (fallback)"
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed (both): {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
}
Err(_) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout (both)");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
}
}
u.last_check = std::time::Instant::now();
continue;
}
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
match result2 {
Ok(Ok(_stream)) => {
let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy {
info!(
rtt = format!("{:.0} ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered (fallback)"
);
}
u.healthy = true;
u.fails = 0;
}
Ok(Err(e)) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed (both): {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
}
Err(_) => {
u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout (both)");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
}
u.fails += 1;
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (no fallback family)");
}
u.last_check = std::time::Instant::now();
}