diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 041e7cb..87d6b52 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,6 +1,8 @@ //! Client Handler +use std::future::Future; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; @@ -8,6 +10,17 @@ use tokio::net::TcpStream; use tokio::time::timeout; use tracing::{debug, warn}; +/// Post-handshake future (relay phase, runs outside handshake timeout) +type PostHandshakeFuture = Pin> + Send>>; + +/// Result of the handshake phase +enum HandshakeOutcome { + /// Handshake succeeded, relay work to do (outside timeout) + NeedsRelay(PostHandshakeFuture), + /// Already fully handled (bad client masking, etc.) + Handled, +} + use crate::config::ProxyConfig; use crate::crypto::SecureRandom; use crate::error::{HandshakeResult, ProxyError, Result}; @@ -24,6 +37,160 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; +pub async fn handle_client_stream( + mut stream: S, + peer: SocketAddr, + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + me_pool: Option>, + ip_tracker: Arc, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + stats.increment_connects_all(); + debug!(peer = %peer, "New connection (generic stream)"); + + let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake); + let stats_for_timeout = stats.clone(); + + // For non-TCP streams, use a synthetic local address + let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) + .parse() + .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, async { + let mut first_bytes = [0u8; 5]; + stream.read_exact(&mut first_bytes).await?; + + let is_tls = tls::is_tls_handshake(&first_bytes[..3]); + debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); + + if is_tls { + let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; + + if tls_len < 512 { + debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(HandshakeOutcome::Handled); + } + + let mut handshake = vec![0u8; 5 + tls_len]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, &rng, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + debug!(peer = %peer, "Reading MTProto handshake through TLS"); + let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; + let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() + .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &mtproto_handshake, tls_reader, tls_writer, peer, + &config, &replay_checker, true, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader: _, writer: _ } => { + stats.increment_connects_bad(); + debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, crypto_writer, success, + upstream_manager, stats, config, buffer_pool, rng, me_pool, + local_addr, peer, ip_tracker.clone(), + ), + ))) + } else { + if !config.general.modes.classic && !config.general.modes.secure { + debug!(peer = %peer, "Non-TLS modes disabled"); + stats.increment_connects_bad(); + let (reader, writer) = tokio::io::split(stream); + handle_bad_client(reader, writer, &first_bytes, &config).await; + return Ok(HandshakeOutcome::Handled); + } + + let mut handshake = [0u8; HANDSHAKE_LEN]; + handshake[..5].copy_from_slice(&first_bytes); + stream.read_exact(&mut handshake[5..]).await?; + + let (read_half, write_half) = tokio::io::split(stream); + + let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( + &handshake, read_half, write_half, peer, + &config, &replay_checker, false, + ).await { + HandshakeResult::Success(result) => result, + HandshakeResult::BadClient { reader, writer } => { + stats.increment_connects_bad(); + handle_bad_client(reader, writer, &handshake, &config).await; + return Ok(HandshakeOutcome::Handled); + } + HandshakeResult::Error(e) => return Err(e), + }; + + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + RunningClientHandler::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + upstream_manager, + stats, + config, + buffer_pool, + rng, + me_pool, + local_addr, + peer, + ip_tracker.clone(), + ) + ))) + } + }).await { + Ok(Ok(outcome)) => outcome, + Ok(Err(e)) => { + debug!(peer = %peer, error = %e, "Handshake failed"); + return Err(e); + } + Err(_) => { + stats_for_timeout.increment_handshake_timeouts(); + debug!(peer = %peer, "Handshake timeout"); + return Err(ProxyError::TgHandshakeTimeout); + } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), + } +} + pub struct ClientHandler; pub struct RunningClientHandler { @@ -72,6 +239,7 @@ impl RunningClientHandler { self.stats.increment_connects_all(); let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, "New connection"); if let Err(e) = configure_client_socket( @@ -85,31 +253,34 @@ impl RunningClientHandler { let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake); let stats = self.stats.clone(); - let result = timeout(handshake_timeout, self.do_handshake()).await; - - match result { - Ok(Ok(())) => { - debug!(peer = %peer, "Connection handled successfully"); - Ok(()) - } + // Phase 1: handshake (with timeout) + let outcome = match timeout(handshake_timeout, self.do_handshake()).await { + Ok(Ok(outcome)) => outcome, Ok(Err(e)) => { debug!(peer = %peer, error = %e, "Handshake failed"); - Err(e) + return Err(e); } Err(_) => { stats.increment_handshake_timeouts(); debug!(peer = %peer, "Handshake timeout"); - Err(ProxyError::TgHandshakeTimeout) + return Err(ProxyError::TgHandshakeTimeout); } + }; + + // Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts) + match outcome { + HandshakeOutcome::NeedsRelay(fut) => fut.await, + HandshakeOutcome::Handled => Ok(()), } } - async fn do_handshake(mut self) -> Result<()> { + async fn do_handshake(mut self) -> Result { let mut first_bytes = [0u8; 5]; self.stream.read_exact(&mut first_bytes).await?; let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); @@ -120,8 +291,9 @@ impl RunningClientHandler { } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; @@ -132,7 +304,7 @@ impl RunningClientHandler { self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = vec![0u8; 5 + tls_len]; @@ -162,7 +334,7 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; @@ -191,37 +363,39 @@ impl RunningClientHandler { } => { stats.increment_connects_bad(); debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> { + async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { let peer = self.peer; + let ip_tracker = self.ip_tracker.clone(); if !self.config.general.modes.classic && !self.config.general.modes.secure { debug!(peer = %peer, "Non-TLS modes disabled"); self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client(reader, writer, &first_bytes, &self.config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } let mut handshake = [0u8; HANDSHAKE_LEN]; @@ -251,26 +425,27 @@ impl RunningClientHandler { HandshakeResult::BadClient { reader, writer } => { stats.increment_connects_bad(); handle_bad_client(reader, writer, &handshake, &config).await; - return Ok(()); + return Ok(HandshakeOutcome::Handled); } HandshakeResult::Error(e) => return Err(e), }; - Self::handle_authenticated_static( - crypto_reader, - crypto_writer, - success, - self.upstream_manager, - self.stats, - self.config, - buffer_pool, - self.rng, - self.me_pool, - local_addr, - peer, - self.ip_tracker, - ) - .await + Ok(HandshakeOutcome::NeedsRelay(Box::pin( + Self::handle_authenticated_static( + crypto_reader, + crypto_writer, + success, + self.upstream_manager, + self.stats, + self.config, + buffer_pool, + self.rng, + self.me_pool, + local_addr, + peer, + self.ip_tracker, + ), + ))) } /// Main dispatch after successful handshake.