diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 92dd373..d6243aa 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,11 +1,13 @@ //! Proxy Defs -pub mod handshake; pub mod client; -pub mod relay; +pub mod direct_relay; +pub mod handshake; pub mod masking; +pub mod middle_relay; +pub mod relay; -pub use handshake::*; pub use client::ClientHandler; +pub use handshake::*; +pub use masking::*; pub use relay::*; -pub use masking::*; \ No newline at end of file diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 4eaaa4c..51daee9 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -59,7 +59,7 @@ pub(crate) fn build_nonce_payload(key_selector: u32, crypto_ts: u32, nonce: &[u8 p } -pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> { +pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, u32, [u8; 16])> { if d.len() < 32 { return Err(ProxyError::InvalidHandshake(format!( "Nonce payload too short: {} bytes", @@ -74,11 +74,12 @@ pub(crate) fn parse_nonce_payload(d: &[u8]) -> Result<(u32, u32, [u8; 16])> { ))); } + let key_select = u32::from_le_bytes(d[4..8].try_into().unwrap()); let schema = u32::from_le_bytes(d[8..12].try_into().unwrap()); let ts = u32::from_le_bytes(d[12..16].try_into().unwrap()); let mut nonce = [0u8; 16]; nonce.copy_from_slice(&d[16..32]); - Ok((schema, ts, nonce)) + Ok((key_select, schema, ts, nonce)) } pub(crate) fn build_handshake_payload( diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 6cb38cb..4906c4b 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -3,6 +3,7 @@ mod codec; mod health; mod pool; +mod pool_nat; mod reader; mod registry; mod send; diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index c59df36..9a978cc 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -1,16 +1,18 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; +use std::sync::OnceLock; use std::sync::atomic::AtomicU64; use std::time::Duration; use bytes::BytesMut; +use rand::Rng; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::sync::{Mutex, RwLock}; use tokio::time::{Instant, timeout}; use tracing::{debug, info, warn}; -use crate::crypto::{SecureRandom, derive_middleproxy_keys}; +use crate::crypto::{SecureRandom, derive_middleproxy_keys, sha256}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; @@ -23,6 +25,7 @@ use super::reader::reader_loop; use super::wire::{IpMaterial, extract_ip_material}; const ME_ACTIVE_PING_SECS: u64 = 25; +const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; pub struct MePool { pub(super) registry: Arc, @@ -30,7 +33,8 @@ pub struct MePool { pub(super) rr: AtomicU64, pub(super) proxy_tag: Option>, proxy_secret: Vec, - nat_ip: Option, + pub(super) nat_ip_cfg: Option, + pub(super) nat_ip_detected: OnceLock, pool_size: usize, } @@ -46,7 +50,8 @@ impl MePool { rr: AtomicU64::new(0), proxy_tag, proxy_secret, - nat_ip, + nat_ip_cfg: nat_ip, + nat_ip_detected: OnceLock::new(), pool_size: 2, }) } @@ -64,24 +69,6 @@ impl MePool { &self.registry } - fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { - let Some(nat_ip) = self.nat_ip else { - return ip; - }; - - match (ip, nat_ip) { - (IpAddr::V4(src), IpAddr::V4(dst)) - if src.is_private() || src.is_loopback() || src.is_unspecified() => - { - IpAddr::V4(dst) - } - (IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => { - IpAddr::V6(dst) - } - (orig, _) => orig, - } - } - fn writers_arc(&self) -> Arc>)>>> { self.writers.clone() @@ -155,6 +142,7 @@ impl MePool { let local_addr = stream.local_addr().map_err(ProxyError::Io)?; let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; + let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; let local_addr_nat = self.translate_our_addr(local_addr); let peer_addr_nat = SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port()); @@ -169,6 +157,14 @@ impl MePool { let ks = self.key_selector(); let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); + info!( + key_selector = format_args!("0x{ks:08x}"), + crypto_ts, + frame_len = nonce_frame.len(), + nonce_frame_hex = %dump, + "Sending ME nonce frame" + ); wr.write_all(&nonce_frame).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; @@ -185,13 +181,20 @@ impl MePool { ))); } - let (schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; + let (srv_key_select, schema, srv_ts, srv_nonce) = parse_nonce_payload(&srv_nonce_payload)?; if schema != RPC_CRYPTO_AES_U32 { + warn!(schema = format_args!("0x{schema:08x}"), "Unsupported ME crypto schema"); return Err(ProxyError::InvalidHandshake(format!( "Unsupported crypto schema: 0x{schema:x}" ))); } + if srv_key_select != ks { + return Err(ProxyError::InvalidHandshake(format!( + "Server key_select 0x{srv_key_select:08x} != client 0x{ks:08x}" + ))); + } + let skew = crypto_ts.abs_diff(srv_ts); if skew > 30 { return Err(ProxyError::InvalidHandshake(format!( @@ -199,6 +202,17 @@ impl MePool { ))); } + info!( + %local_addr, + %local_addr_nat, + %peer_addr, + %peer_addr_nat, + key_selector = format_args!("0x{ks:08x}"), + crypto_schema = format_args!("0x{schema:08x}"), + skew_secs = skew, + "ME key derivation parameters" + ); + let ts_bytes = crypto_ts.to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes(); let client_port_bytes = local_addr_nat.port().to_le_bytes(); @@ -250,10 +264,29 @@ impl MePool { srv_v6_opt.as_ref(), ); + let diag = std::env::var("ME_DIAG").map(|v| v == "1").unwrap_or(false); let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); let hs_frame = build_rpc_frame(-1, &hs_payload); + if diag { + info!( + write_key = %hex_dump(&wk), + write_iv = %hex_dump(&wi), + read_key = %hex_dump(&rk), + read_iv = %hex_dump(&ri), + hs_plain = %hex_dump(&hs_frame), + proxy_secret_sha256 = %hex_dump(&sha256(secret)), + "ME diag: derived keys and handshake plaintext" + ); + } + let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + if diag { + info!( + hs_cipher = %hex_dump(&encrypted_hs), + "ME diag: handshake ciphertext" + ); + } wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?; wr.flush().await.map_err(ProxyError::Io)?; @@ -369,7 +402,10 @@ impl MePool { tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); loop { - tokio::time::sleep(Duration::from_secs(ME_ACTIVE_PING_SECS)).await; + let jitter = rand::rng() + .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + tokio::time::sleep(Duration::from_secs(wait)).await; let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); p.extend_from_slice(&ping_id.to_le_bytes()); @@ -387,3 +423,18 @@ impl MePool { } } + +fn hex_dump(data: &[u8]) -> String { + const MAX: usize = 64; + let mut out = String::with_capacity(data.len() * 2 + 3); + for (i, b) in data.iter().take(MAX).enumerate() { + if i > 0 { + out.push(' '); + } + out.push_str(&format!("{b:02x}")); + } + if data.len() > MAX { + out.push_str(" …"); + } + out +} diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs new file mode 100644 index 0000000..6891919 --- /dev/null +++ b/src/transport/middle_proxy/pool_nat.rs @@ -0,0 +1,80 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use tracing::{info, warn}; + +use crate::error::{ProxyError, Result}; + +use super::MePool; + +impl MePool { + pub(super) fn translate_ip_for_nat(&self, ip: IpAddr) -> IpAddr { + let nat_ip = self + .nat_ip_cfg + .or_else(|| self.nat_ip_detected.get().copied()); + + let Some(nat_ip) = nat_ip else { + return ip; + }; + + match (ip, nat_ip) { + (IpAddr::V4(src), IpAddr::V4(dst)) + if is_privateish(IpAddr::V4(src)) + || src.is_loopback() + || src.is_unspecified() => + { + IpAddr::V4(dst) + } + (IpAddr::V6(src), IpAddr::V6(dst)) if src.is_loopback() || src.is_unspecified() => { + IpAddr::V6(dst) + } + (orig, _) => orig, + } + } + + pub(super) async fn maybe_detect_nat_ip(&self, local_ip: IpAddr) -> Option { + if self.nat_ip_cfg.is_some() { + return self.nat_ip_cfg; + } + + if !(is_privateish(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) { + return None; + } + + if let Some(ip) = self.nat_ip_detected.get().copied() { + return Some(ip); + } + + match fetch_public_ipv4().await { + Ok(Some(ip)) => { + let _ = self.nat_ip_detected.set(IpAddr::V4(ip)); + info!(public_ip = %ip, "Auto-detected public IP for NAT translation"); + Some(IpAddr::V4(ip)) + } + Ok(None) => None, + Err(e) => { + warn!(error = %e, "Failed to auto-detect public IP"); + None + } + } + } +} + +async fn fetch_public_ipv4() -> Result> { + let res = reqwest::get("https://checkip.amazonaws.com").await.map_err(|e| { + ProxyError::Proxy(format!("public IP detection request failed: {e}")) + })?; + + let text = res.text().await.map_err(|e| { + ProxyError::Proxy(format!("public IP detection read failed: {e}")) + })?; + + let ip = text.trim().parse().ok(); + Ok(ip) +} + +fn is_privateish(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(v4) => v4.is_private() || v4.is_link_local(), + IpAddr::V6(v6) => v6.is_unique_local(), + } +}