From e4f90cd7c123de1192124342ee1c649f2a6c3ac1 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:10:59 +0300 Subject: [PATCH] ME Ping in log --- src/main.rs | 70 ++++- src/transport/middle_proxy/handshake.rs | 369 ++++++++++++++++++++++++ src/transport/middle_proxy/mod.rs | 3 + src/transport/middle_proxy/ping.rs | 164 +++++++++++ src/transport/middle_proxy/pool.rs | 343 +--------------------- 5 files changed, 619 insertions(+), 330 deletions(-) create mode 100644 src/transport/middle_proxy/handshake.rs create mode 100644 src/transport/middle_proxy/ping.rs diff --git a/src/main.rs b/src/main.rs index 48bd45f..882ecdb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,10 @@ use crate::ip_tracker::UserIpTracker; use crate::proxy::ClientHandler; use crate::stats::{ReplayChecker, Stats}; use crate::stream::BufferPool; -use crate::transport::middle_proxy::{MePool, fetch_proxy_config, stun_probe}; +use crate::transport::middle_proxy::{ + MePool, fetch_proxy_config, run_me_ping, MePingFamily, MePingSample, format_sample_line, + stun_probe, +}; use crate::transport::{ListenOptions, UpstreamManager, create_listener}; use crate::util::ip::detect_ip; use crate::protocol::constants::{TG_MIDDLE_PROXIES_V4, TG_MIDDLE_PROXIES_V6}; @@ -388,6 +391,71 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai info!("Transport: Direct TCP (standard DCs only)"); } + // Middle-End ping before DC connectivity + if let Some(ref pool) = me_pool { + let me_results = run_me_ping(pool, &rng).await; + + let v4_ok = me_results.iter().any(|r| { + matches!(r.family, MePingFamily::V4) + && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + }); + let v6_ok = me_results.iter().any(|r| { + matches!(r.family, MePingFamily::V6) + && r.samples.iter().any(|s| s.error.is_none() && s.handshake_ms.is_some()) + }); + + info!("================= Telegram ME Connectivity ================="); + if v4_ok && v6_ok { + info!(" IPv4 and IPv6 available"); + } else if v4_ok { + info!(" IPv4 only / IPv6 unavailable"); + } else if v6_ok { + info!(" IPv6 only / IPv4 unavailable"); + } else { + info!(" No ME connectivity"); + } + info!(" via direct"); + info!("============================================================"); + + use std::collections::BTreeMap; + let mut grouped: BTreeMap> = BTreeMap::new(); + for report in me_results { + for s in report.samples { + let key = s.dc.abs(); + grouped.entry(key).or_default().push(s); + } + } + + let family_order = if prefer_ipv6 { + vec![(MePingFamily::V6, true), (MePingFamily::V6, false), (MePingFamily::V4, true), (MePingFamily::V4, false)] + } else { + vec![(MePingFamily::V4, true), (MePingFamily::V4, false), (MePingFamily::V6, true), (MePingFamily::V6, false)] + }; + + for (dc_abs, samples) in grouped { + for (family, is_pos) in &family_order { + let fam_samples: Vec<&MePingSample> = samples + .iter() + .filter(|s| matches!(s.family, f if &f == family) && (s.dc >= 0) == *is_pos) + .collect(); + if fam_samples.is_empty() { + continue; + } + + let fam_label = match family { + MePingFamily::V4 => "IPv4", + MePingFamily::V6 => "IPv6", + }; + info!(" DC{} [{}]", dc_abs, fam_label); + for sample in fam_samples { + let line = format_sample_line(sample); + info!("{}", line); + } + } + } + info!("============================================================"); + } + info!("================= Telegram DC Connectivity ================="); let ping_results = upstream_manager diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs new file mode 100644 index 0000000..fdd4d15 --- /dev/null +++ b/src/transport/middle_proxy/handshake.rs @@ -0,0 +1,369 @@ +use std::net::{IpAddr, SocketAddr}; +use std::time::{Duration, Instant}; + +use bytes::BytesMut; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tracing::{debug, info, warn}; + +use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::error::{ProxyError, Result}; +use crate::protocol::constants::{ + ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, + RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, +}; + +use super::codec::{ + build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, + cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, +}; +use super::wire::{extract_ip_material, IpMaterial}; +use super::MePool; + +/// Result of a successful ME handshake with timings. +pub(crate) struct HandshakeOutput { + pub rd: ReadHalf, + pub wr: WriteHalf, + pub read_key: [u8; 32], + pub read_iv: [u8; 16], + pub write_key: [u8; 32], + pub write_iv: [u8; 16], + pub handshake_ms: f64, +} + +impl MePool { + /// TCP connect with timeout + return RTT in milliseconds. + pub(crate) async fn connect_tcp(&self, addr: SocketAddr) -> Result<(TcpStream, f64)> { + let start = Instant::now(); + let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??; + let connect_ms = start.elapsed().as_secs_f64() * 1000.0; + stream.set_nodelay(true).ok(); + Ok((stream, connect_ms)) + } + + /// Perform full ME RPC handshake on an established TCP stream. + /// Returns cipher keys/ivs and split halves; does not register writer. + pub(crate) async fn handshake_only( + &self, + stream: TcpStream, + addr: SocketAddr, + rng: &SecureRandom, + ) -> Result { + let hs_start = Instant::now(); + + let local_addr = stream.local_addr().map_err(ProxyError::Io)?; + let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; + + let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; + let reflected = if self.nat_probe { + self.maybe_reflect_public_addr().await + } else { + None + }; + + let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); + let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); + let (mut rd, mut wr) = tokio::io::split(stream); + + let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); + let crypto_ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as u32; + + let ks = self.key_selector().await; + let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); + let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); + info!( + key_selector = format_args!("0x{ks:08x}"), + crypto_ts, + frame_len = nonce_frame.len(), + nonce_frame_hex = %dump, + "Sending ME nonce frame" + ); + wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let (srv_seq, srv_nonce_payload) = timeout( + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), + read_rpc_frame_plaintext(&mut rd), + ) + .await + .map_err(|_| ProxyError::TgHandshakeTimeout)??; + + if srv_seq != -2 { + return Err(ProxyError::InvalidHandshake(format!("Expected seq=-2, got {srv_seq}"))); + } + + let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; + if schema != RPC_CRYPTO_AES_U32 { + warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); + return Err(ProxyError::InvalidHandshake(format!( + "Unsupported crypto schema: 0x{schema:x}" + ))); + } + + if srv_key_select != ks { + return Err(ProxyError::InvalidHandshake(format!( + "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" + ))); + } + + let skew = crypto_ts.abs_diff(srv_ts); + if skew > 30 { + return Err(ProxyError::InvalidHandshake(format!( + "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" + ))); + } + + info!( + %local_addr, + %local_addr_nat, + reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), + %peer_addr, + %peer_addr_nat, + key_selector = format_args!("0x{ks:08x}"), + crypto_schema = format_args!("0x{schema:08x}"), + skew_secs = skew, + "ME key derivation parameters" + ); + + let ts_bytes = crypto_ts.to_le_bytes(); + let server_port_bytes = peer_addr_nat.port().to_le_bytes(); + let client_port_bytes = local_addr_nat.port().to_le_bytes(); + + let server_ip = extract_ip_material(peer_addr_nat); + let client_ip = extract_ip_material(local_addr_nat); + + let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = match (server_ip, client_ip) { + (IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => { + srv.reverse(); + clt.reverse(); + (Some(srv), Some(clt), None, None, clt, srv) + } + (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { + let zero = [0u8; 4]; + (None, None, Some(clt), Some(srv), zero, zero) + } + _ => { + return Err(ProxyError::InvalidHandshake( + "mixed IPv4/IPv6 endpoints are not supported for ME key derivation".to_string(), + )); + } + }; + + let diag_level: u8 = std::env::var("ME_DIAG").ok().and_then(|v| v.parse().ok()).unwrap_or(0); + + let secret: Vec = self.proxy_secret.read().await.clone(); + + let prekey_client = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let prekey_server = build_middleproxy_prekey( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let (wk, wi) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"CLIENT", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + let (rk, ri) = derive_middleproxy_keys( + &srv_nonce, + &my_nonce, + &ts_bytes, + srv_ip_opt.as_ref().map(|x| &x[..]), + &client_port_bytes, + b"SERVER", + clt_ip_opt.as_ref().map(|x| &x[..]), + &server_port_bytes, + &secret, + clt_v6_opt.as_ref(), + srv_v6_opt.as_ref(), + ); + + let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); + let hs_frame = build_rpc_frame(-1, &hs_payload); + if diag_level >= 1 { + info!( + write_key = %hex_dump(&wk), + write_iv = %hex_dump(&wi), + read_key = %hex_dump(&rk), + read_iv = %hex_dump(&ri), + srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), + srv_port = %hex_dump(&server_port_bytes), + clt_port = %hex_dump(&client_port_bytes), + crypto_ts = %hex_dump(&ts_bytes), + nonce_srv = %hex_dump(&srv_nonce), + nonce_clt = %hex_dump(&my_nonce), + prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), + prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), + hs_plain = %hex_dump(&hs_frame), + proxy_secret_sha256 = %hex_dump(&sha256(&secret)), + "ME diag: derived keys and handshake plaintext" + ); + } + if diag_level >= 2 { + info!( + prekey_client = %hex_dump(&prekey_client), + prekey_server = %hex_dump(&prekey_server), + "ME diag: full prekey buffers" + ); + } + + let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + if diag_level >= 1 { + info!( + hs_cipher = %hex_dump(&encrypted_hs), + "ME diag: handshake ciphertext" + ); + } + wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; + wr.flush().await.map_err(ProxyError::Io)?; + + let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); + let mut enc_buf = BytesMut::with_capacity(256); + let mut dec_buf = BytesMut::with_capacity(256); + let mut read_iv = ri; + let mut handshake_ok = false; + + while Instant::now() < deadline && !handshake_ok { + let remaining = deadline - Instant::now(); + let mut tmp = [0u8; 256]; + let n = match timeout(remaining, rd.read(&mut tmp)).await { + Ok(Ok(0)) => { + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ME closed during handshake", + ))); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => return Err(ProxyError::TgHandshakeTimeout), + }; + + enc_buf.extend_from_slice(&tmp[..n]); + + let blocks = enc_buf.len() / 16 * 16; + if blocks > 0 { + let mut chunk = vec![0u8; blocks]; + chunk.copy_from_slice(&enc_buf[..blocks]); + read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; + dec_buf.extend_from_slice(&chunk); + let _ = enc_buf.split_to(blocks); + } + + while dec_buf.len() >= 4 { + let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; + + if fl == 4 { + let _ = dec_buf.split_to(4); + continue; + } + if !(12..=(1 << 24)).contains(&fl) { + return Err(ProxyError::InvalidHandshake(format!( + "Bad HS response frame len: {fl}" + ))); + } + if dec_buf.len() < fl { + break; + } + + let frame = dec_buf.split_to(fl); + let pe = fl - 4; + let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); + let ac = crate::crypto::crc32(&frame[..pe]); + if ec != ac { + return Err(ProxyError::InvalidHandshake(format!( + "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" + ))); + } + + let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + if hs_type == RPC_HANDSHAKE_ERROR_U32 { + let err_code = if frame.len() >= 16 { + i32::from_le_bytes(frame[12..16].try_into().unwrap()) + } else { + -1 + }; + return Err(ProxyError::InvalidHandshake(format!( + "ME rejected handshake (error={err_code})" + ))); + } + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + + handshake_ok = true; + break; + } + } + + if !handshake_ok { + return Err(ProxyError::TgHandshakeTimeout); + } + + let handshake_ms = hs_start.elapsed().as_secs_f64() * 1000.0; + info!(%addr, "RPC handshake OK"); + + Ok(HandshakeOutput { + rd, + wr, + read_key: rk, + read_iv, + write_key: wk, + write_iv, + handshake_ms, + }) + } +} + +fn hex_dump(data: &[u8]) -> String { + const MAX: usize = 64; + let mut out = String::with_capacity(data.len() * 2 + 3); + for (i, b) in data.iter().take(MAX).enumerate() { + if i > 0 { + out.push(' '); + } + out.push_str(&format!("{b:02x}")); + } + if data.len() > MAX { + out.push_str(" …"); + } + out +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 72c0c24..443c189 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -1,9 +1,11 @@ //! Middle Proxy RPC transport. mod codec; +mod handshake; mod health; mod pool; mod pool_nat; +mod ping; mod reader; mod registry; mod send; @@ -14,6 +16,7 @@ mod wire; use bytes::Bytes; pub use health::me_health_monitor; +pub use ping::{run_me_ping, format_sample_line, MePingReport, MePingSample, MePingFamily}; pub use pool::MePool; pub use pool_nat::{stun_probe, StunProbeResult}; pub use registry::ConnRegistry; diff --git a/src/transport/middle_proxy/ping.rs b/src/transport/middle_proxy/ping.rs new file mode 100644 index 0000000..22b1f6d --- /dev/null +++ b/src/transport/middle_proxy/ping.rs @@ -0,0 +1,164 @@ +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use crate::crypto::SecureRandom; +use crate::error::ProxyError; + +use super::MePool; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MePingFamily { + V4, + V6, +} + +#[derive(Debug, Clone)] +pub struct MePingSample { + pub dc: i32, + pub addr: SocketAddr, + pub connect_ms: Option, + pub handshake_ms: Option, + pub error: Option, + pub family: MePingFamily, +} + +#[derive(Debug, Clone)] +pub struct MePingReport { + pub dc: i32, + pub family: MePingFamily, + pub samples: Vec, +} + +pub fn format_sample_line(sample: &MePingSample) -> String { + let sign = if sample.dc >= 0 { "+" } else { "-" }; + let addr = format!("{}:{}", sample.addr.ip(), sample.addr.port()); + + match (sample.connect_ms, sample.handshake_ms.as_ref(), sample.error.as_ref()) { + (Some(conn), Some(hs), None) => format!( + " {sign} {addr}\tPing: {:.0} ms / RPC: {:.0} ms / OK", + conn, hs + ), + (Some(conn), None, Some(err)) => format!( + " {sign} {addr}\tPing: {:.0} ms / RPC: FAIL ({err})", + conn + ), + (None, _, Some(err)) => format!(" {sign} {addr}\tPing: FAIL ({err})"), + (Some(conn), None, None) => format!(" {sign} {addr}\tPing: {:.0} ms / RPC: FAIL", conn), + _ => format!(" {sign} {addr}\tPing: FAIL"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + fn sample(base: MePingSample) -> MePingSample { + base + } + + #[test] + fn ok_line_contains_both_timings() { + let s = sample(MePingSample { + dc: 4, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 8888), + connect_ms: Some(12.3), + handshake_ms: Some(34.7), + error: None, + family: MePingFamily::V4, + }); + let line = format_sample_line(&s); + assert!(line.contains("Ping: 12 ms")); + assert!(line.contains("RPC: 35 ms")); + assert!(line.contains("OK")); + } + + #[test] + fn error_line_mentions_reason() { + let s = sample(MePingSample { + dc: -5, + addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8)), 80), + connect_ms: Some(10.0), + handshake_ms: None, + error: Some("handshake timeout".to_string()), + family: MePingFamily::V4, + }); + let line = format_sample_line(&s); + assert!(line.contains("- 5.6.7.8:80")); + assert!(line.contains("handshake timeout")); + } +} + +pub async fn run_me_ping(pool: &Arc, rng: &SecureRandom) -> Vec { + let mut reports = Vec::new(); + + let v4_map = pool.proxy_map_v4.read().await.clone(); + let v6_map = pool.proxy_map_v6.read().await.clone(); + + let mut grouped: Vec<(MePingFamily, i32, Vec<(IpAddr, u16)>)> = Vec::new(); + for (dc, addrs) in v4_map { + grouped.push((MePingFamily::V4, dc, addrs)); + } + for (dc, addrs) in v6_map { + grouped.push((MePingFamily::V6, dc, addrs)); + } + + for (family, dc, addrs) in grouped { + let mut samples = Vec::new(); + for (ip, port) in addrs { + let addr = SocketAddr::new(ip, port); + let mut connect_ms = None; + let mut handshake_ms = None; + let mut error = None; + + match pool.connect_tcp(addr).await { + Ok((stream, conn_rtt)) => { + connect_ms = Some(conn_rtt); + match pool.handshake_only(stream, addr, rng).await { + Ok(hs) => { + handshake_ms = Some(hs.handshake_ms); + // drop halves to close + drop(hs.rd); + drop(hs.wr); + } + Err(e) => { + error = Some(short_err(&e)); + } + } + } + Err(e) => { + error = Some(short_err(&e)); + } + } + + samples.push(MePingSample { + dc, + addr, + connect_ms, + handshake_ms, + error, + family, + }); + } + + reports.push(MePingReport { + dc, + family, + samples, + }); + } + + reports +} + +fn short_err(err: &ProxyError) -> String { + match err { + ProxyError::ConnectionTimeout { .. } => "connect timeout".to_string(), + ProxyError::TgHandshakeTimeout => "handshake timeout".to_string(), + ProxyError::InvalidHandshake(e) => format!("bad handshake: {e}"), + ProxyError::Crypto(e) => format!("crypto: {e}"), + ProxyError::Proxy(e) => format!("proxy: {e}"), + ProxyError::Io(e) => format!("io: {e}"), + _ => format!("{err}"), + } +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index f38a81c..caa069e 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -2,28 +2,20 @@ use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::sync::atomic::{AtomicI32, AtomicU64}; -use std::time::Duration; - use bytes::BytesMut; use rand::Rng; use rand::seq::SliceRandom; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; use tokio::sync::{Mutex, RwLock}; -use tokio::time::{Instant, timeout}; use tracing::{debug, info, warn}; +use std::time::Duration; -use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; +use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use super::ConnRegistry; -use super::codec::{ - RpcWriter, build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, - cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, -}; +use super::codec::RpcWriter; use super::reader::reader_loop; -use super::wire::{IpMaterial, extract_ip_material}; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; @@ -157,7 +149,7 @@ impl MePool { // No-op here to avoid total outage. } - async fn key_selector(&self) -> u32 { + pub(super) async fn key_selector(&self) -> u32 { let secret = self.proxy_secret.read().await; if secret.len() >= 4 { u32::from_le_bytes([secret[0], secret[1], secret[2], secret[3]]) @@ -223,326 +215,19 @@ impl MePool { Ok(()) } - pub(crate) async fn connect_one( - &self, - addr: SocketAddr, - rng: &SecureRandom, - ) -> Result<()> { - let secret_guard = self.proxy_secret.read().await; - let secret: Vec = secret_guard.clone(); - if secret.len() < 32 { - return Err(ProxyError::Proxy( - "proxy-secret too short for ME auth".into(), - )); + pub(crate) async fn connect_one(&self, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { + let secret_len = self.proxy_secret.read().await.len(); + if secret_len < 32 { + return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } - let stream = timeout( - Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), - TcpStream::connect(addr), - ) - .await - .map_err(|_| ProxyError::ConnectionTimeout { - addr: addr.to_string(), - })? - .map_err(ProxyError::Io)?; - stream.set_nodelay(true).ok(); - - let local_addr = stream.local_addr().map_err(ProxyError::Io)?; - let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; - let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; - let reflected = if self.nat_probe { - self.maybe_reflect_public_addr().await - } else { - None - }; - let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected); - let peer_addr_nat = - SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); - let (mut rd, mut wr) = tokio::io::split(stream); - - let my_nonce: [u8; 16] = rng.bytes(16).try_into().unwrap(); - let crypto_ts = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as u32; - - let ks = self.key_selector().await; - let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); - let nonce_frame = build_rpc_frame(-2, &nonce_payload); - let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); - info!( - key_selector = format_args!("0x{ks:08x}"), - crypto_ts, - frame_len = nonce_frame.len(), - nonce_frame_hex = %dump, - "Sending ME nonce frame" - ); - wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - let (srv_seq, srv_nonce_payload) = timeout( - Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS), - read_rpc_frame_plaintext(&mut rd), - ) - .await - .map_err(|_| ProxyError::TgHandshakeTimeout)??; - - if srv_seq != -2 { - return Err(ProxyError::InvalidHandshake(format!( - "Expected seq=-2, got {srv_seq}" - ))); - } - - let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; - if schema != RPC_CRYPTO_AES_U32 { - warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); - return Err(ProxyError::InvalidHandshake(format!( - "Unsupported crypto schema: 0x{schema:x}" - ))); - } - - if srv_key_select != ks { - return Err(ProxyError::InvalidHandshake(format!( - "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" - ))); - } - - let skew = crypto_ts.abs_diff(srv_ts); - if skew > 30 { - return Err(ProxyError::InvalidHandshake(format!( - "nonce crypto_ts skew too large: client={crypto_ts}, server={srv_ts}, skew={skew}s" - ))); - } - - info!( - %local_addr, - %local_addr_nat, - reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), - %peer_addr, - %peer_addr_nat, - key_selector = format_args!("0x{ks:08x}"), - crypto_schema = format_args!("0x{schema:08x}"), - skew_secs = skew, - "ME key derivation parameters" - ); - - let ts_bytes = crypto_ts.to_le_bytes(); - let server_port_bytes = peer_addr_nat.port().to_le_bytes(); - let client_port_bytes = local_addr_nat.port().to_le_bytes(); - - let server_ip = extract_ip_material(peer_addr_nat); - let client_ip = extract_ip_material(local_addr_nat); - - let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) = - match (server_ip, client_ip) { - // IPv4: reverse byte order for KDF (Python/C reference behavior) - (IpMaterial::V4(mut srv), IpMaterial::V4(mut clt)) => { - srv.reverse(); - clt.reverse(); - (Some(srv), Some(clt), None, None, clt, srv) - } - (IpMaterial::V6(srv), IpMaterial::V6(clt)) => { - let zero = [0u8; 4]; - (None, None, Some(clt), Some(srv), zero, zero) - } - _ => { - return Err(ProxyError::InvalidHandshake( - "mixed IPv4/IPv6 endpoints are not supported for ME key derivation" - .to_string(), - )); - } - }; - - let diag_level: u8 = std::env::var("ME_DIAG") - .ok() - .and_then(|v| v.parse().ok()) - .unwrap_or(0); - - let prekey_client = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let prekey_server = build_middleproxy_prekey( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - - let (wk, wi) = derive_middleproxy_keys( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"CLIENT", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - let (rk, ri) = derive_middleproxy_keys( - &srv_nonce, - &my_nonce, - &ts_bytes, - srv_ip_opt.as_ref().map(|x| &x[..]), - &client_port_bytes, - b"SERVER", - clt_ip_opt.as_ref().map(|x| &x[..]), - &server_port_bytes, - &secret, - clt_v6_opt.as_ref(), - srv_v6_opt.as_ref(), - ); - - let hs_payload = - build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); - let hs_frame = build_rpc_frame(-1, &hs_payload); - if diag_level >= 1 { - info!( - write_key = %hex_dump(&wk), - write_iv = %hex_dump(&wi), - read_key = %hex_dump(&rk), - read_iv = %hex_dump(&ri), - srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(), - srv_port = %hex_dump(&server_port_bytes), - clt_port = %hex_dump(&client_port_bytes), - crypto_ts = %hex_dump(&ts_bytes), - nonce_srv = %hex_dump(&srv_nonce), - nonce_clt = %hex_dump(&my_nonce), - prekey_sha256_client = %hex_dump(&sha256(&prekey_client)), - prekey_sha256_server = %hex_dump(&sha256(&prekey_server)), - hs_plain = %hex_dump(&hs_frame), - proxy_secret_sha256 = %hex_dump(&sha256(&secret)), - "ME diag: derived keys and handshake plaintext" - ); - } - if diag_level >= 2 { - info!( - prekey_client = %hex_dump(&prekey_client), - prekey_server = %hex_dump(&prekey_server), - "ME diag: full prekey buffers" - ); - } - - let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; - if diag_level >= 1 { - info!( - hs_cipher = %hex_dump(&encrypted_hs), - "ME diag: handshake ciphertext" - ); - } - wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; - wr.flush().await.map_err(ProxyError::Io)?; - - let deadline = Instant::now() + Duration::from_secs(ME_HANDSHAKE_TIMEOUT_SECS); - let mut enc_buf = BytesMut::with_capacity(256); - let mut dec_buf = BytesMut::with_capacity(256); - let mut read_iv = ri; - let mut handshake_ok = false; - - while Instant::now() < deadline && !handshake_ok { - let remaining = deadline - Instant::now(); - let mut tmp = [0u8; 256]; - let n = match timeout(remaining, rd.read(&mut tmp)).await { - Ok(Ok(0)) => { - return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::UnexpectedEof, - "ME closed during handshake", - ))); - } - Ok(Ok(n)) => n, - Ok(Err(e)) => return Err(ProxyError::Io(e)), - Err(_) => return Err(ProxyError::TgHandshakeTimeout), - }; - - enc_buf.extend_from_slice(&tmp[..n]); - - let blocks = enc_buf.len() / 16 * 16; - if blocks > 0 { - let mut chunk = vec![0u8; blocks]; - chunk.copy_from_slice(&enc_buf[..blocks]); - read_iv = cbc_decrypt_inplace(&rk, &read_iv, &mut chunk)?; - dec_buf.extend_from_slice(&chunk); - let _ = enc_buf.split_to(blocks); - } - - while dec_buf.len() >= 4 { - let fl = u32::from_le_bytes(dec_buf[0..4].try_into().unwrap()) as usize; - - if fl == 4 { - let _ = dec_buf.split_to(4); - continue; - } - if !(12..=(1 << 24)).contains(&fl) { - return Err(ProxyError::InvalidHandshake(format!( - "Bad HS response frame len: {fl}" - ))); - } - if dec_buf.len() < fl { - break; - } - - let frame = dec_buf.split_to(fl); - let pe = fl - 4; - let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - let ac = crate::crypto::crc32(&frame[..pe]); - if ec != ac { - return Err(ProxyError::InvalidHandshake(format!( - "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" - ))); - } - - let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); - if hs_type == RPC_HANDSHAKE_ERROR_U32 { - let err_code = if frame.len() >= 16 { - i32::from_le_bytes(frame[12..16].try_into().unwrap()) - } else { - -1 - }; - return Err(ProxyError::InvalidHandshake(format!( - "ME rejected handshake (error={err_code})" - ))); - } - if hs_type != RPC_HANDSHAKE_U32 { - return Err(ProxyError::InvalidHandshake(format!( - "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" - ))); - } - - handshake_ok = true; - break; - } - } - - if !handshake_ok { - return Err(ProxyError::TgHandshakeTimeout); - } - - info!(%addr, "RPC handshake OK"); + let (stream, _connect_ms) = self.connect_tcp(addr).await?; + let hs = self.handshake_only(stream, addr, rng).await?; let rpc_w = Arc::new(Mutex::new(RpcWriter { - writer: wr, - key: wk, - iv: write_iv, + writer: hs.wr, + key: hs.write_key, + iv: hs.write_iv, seq_no: 0, })); self.writers.write().await.push((addr, rpc_w.clone())); @@ -554,7 +239,7 @@ impl MePool { let w_pool_ping = self.writers_arc(); tokio::spawn(async move { if let Err(e) = - reader_loop(rd, rk, read_iv, reg, enc_buf, dec_buf, w_pong.clone()).await + reader_loop(hs.rd, hs.read_key, hs.read_iv, reg, BytesMut::new(), BytesMut::new(), w_pong.clone()).await { warn!(error = %e, "ME reader ended"); }