Refactor client handshake handling for clarity
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
//! Client Handler
|
//! Client Handler
|
||||||
|
|
||||||
|
use std::future::Future;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
@@ -8,6 +10,17 @@ use tokio::net::TcpStream;
|
|||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
/// Post-handshake future (relay phase, runs outside handshake timeout)
|
||||||
|
type PostHandshakeFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
|
||||||
|
|
||||||
|
/// Result of the handshake phase
|
||||||
|
enum HandshakeOutcome {
|
||||||
|
/// Handshake succeeded, relay work to do (outside timeout)
|
||||||
|
NeedsRelay(PostHandshakeFuture),
|
||||||
|
/// Already fully handled (bad client masking, etc.)
|
||||||
|
Handled,
|
||||||
|
}
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
use crate::error::{HandshakeResult, ProxyError, Result};
|
use crate::error::{HandshakeResult, ProxyError, Result};
|
||||||
@@ -24,6 +37,160 @@ use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle
|
|||||||
use crate::proxy::masking::handle_bad_client;
|
use crate::proxy::masking::handle_bad_client;
|
||||||
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||||
|
|
||||||
|
pub async fn handle_client_stream<S>(
|
||||||
|
mut stream: S,
|
||||||
|
peer: SocketAddr,
|
||||||
|
config: Arc<ProxyConfig>,
|
||||||
|
stats: Arc<Stats>,
|
||||||
|
upstream_manager: Arc<UpstreamManager>,
|
||||||
|
replay_checker: Arc<ReplayChecker>,
|
||||||
|
buffer_pool: Arc<BufferPool>,
|
||||||
|
rng: Arc<SecureRandom>,
|
||||||
|
me_pool: Option<Arc<MePool>>,
|
||||||
|
ip_tracker: Arc<UserIpTracker>,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
stats.increment_connects_all();
|
||||||
|
debug!(peer = %peer, "New connection (generic stream)");
|
||||||
|
|
||||||
|
let handshake_timeout = Duration::from_secs(config.timeouts.client_handshake);
|
||||||
|
let stats_for_timeout = stats.clone();
|
||||||
|
|
||||||
|
// For non-TCP streams, use a synthetic local address
|
||||||
|
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
|
||||||
|
.parse()
|
||||||
|
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
|
||||||
|
|
||||||
|
// Phase 1: handshake (with timeout)
|
||||||
|
let outcome = match timeout(handshake_timeout, async {
|
||||||
|
let mut first_bytes = [0u8; 5];
|
||||||
|
stream.read_exact(&mut first_bytes).await?;
|
||||||
|
|
||||||
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
|
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||||
|
|
||||||
|
if is_tls {
|
||||||
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
|
if tls_len < 512 {
|
||||||
|
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
let (reader, writer) = tokio::io::split(stream);
|
||||||
|
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
||||||
|
return Ok(HandshakeOutcome::Handled);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut handshake = vec![0u8; 5 + tls_len];
|
||||||
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
|
stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
|
let (read_half, write_half) = tokio::io::split(stream);
|
||||||
|
|
||||||
|
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
|
||||||
|
&handshake, read_half, write_half, peer,
|
||||||
|
&config, &replay_checker, &rng,
|
||||||
|
).await {
|
||||||
|
HandshakeResult::Success(result) => result,
|
||||||
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
|
return Ok(HandshakeOutcome::Handled);
|
||||||
|
}
|
||||||
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||||
|
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
|
||||||
|
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
||||||
|
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||||
|
|
||||||
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
|
&mtproto_handshake, tls_reader, tls_writer, peer,
|
||||||
|
&config, &replay_checker, true,
|
||||||
|
).await {
|
||||||
|
HandshakeResult::Success(result) => result,
|
||||||
|
HandshakeResult::BadClient { reader: _, writer: _ } => {
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
|
return Ok(HandshakeOutcome::Handled);
|
||||||
|
}
|
||||||
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||||
|
RunningClientHandler::handle_authenticated_static(
|
||||||
|
crypto_reader, crypto_writer, success,
|
||||||
|
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
||||||
|
local_addr, peer, ip_tracker.clone(),
|
||||||
|
),
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
if !config.general.modes.classic && !config.general.modes.secure {
|
||||||
|
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
let (reader, writer) = tokio::io::split(stream);
|
||||||
|
handle_bad_client(reader, writer, &first_bytes, &config).await;
|
||||||
|
return Ok(HandshakeOutcome::Handled);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||||
|
handshake[..5].copy_from_slice(&first_bytes);
|
||||||
|
stream.read_exact(&mut handshake[5..]).await?;
|
||||||
|
|
||||||
|
let (read_half, write_half) = tokio::io::split(stream);
|
||||||
|
|
||||||
|
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||||
|
&handshake, read_half, write_half, peer,
|
||||||
|
&config, &replay_checker, false,
|
||||||
|
).await {
|
||||||
|
HandshakeResult::Success(result) => result,
|
||||||
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
|
return Ok(HandshakeOutcome::Handled);
|
||||||
|
}
|
||||||
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||||
|
RunningClientHandler::handle_authenticated_static(
|
||||||
|
crypto_reader,
|
||||||
|
crypto_writer,
|
||||||
|
success,
|
||||||
|
upstream_manager,
|
||||||
|
stats,
|
||||||
|
config,
|
||||||
|
buffer_pool,
|
||||||
|
rng,
|
||||||
|
me_pool,
|
||||||
|
local_addr,
|
||||||
|
peer,
|
||||||
|
ip_tracker.clone(),
|
||||||
|
)
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}).await {
|
||||||
|
Ok(Ok(outcome)) => outcome,
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
stats_for_timeout.increment_handshake_timeouts();
|
||||||
|
debug!(peer = %peer, "Handshake timeout");
|
||||||
|
return Err(ProxyError::TgHandshakeTimeout);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
|
||||||
|
match outcome {
|
||||||
|
HandshakeOutcome::NeedsRelay(fut) => fut.await,
|
||||||
|
HandshakeOutcome::Handled => Ok(()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ClientHandler;
|
pub struct ClientHandler;
|
||||||
|
|
||||||
pub struct RunningClientHandler {
|
pub struct RunningClientHandler {
|
||||||
@@ -72,6 +239,7 @@ impl RunningClientHandler {
|
|||||||
self.stats.increment_connects_all();
|
self.stats.increment_connects_all();
|
||||||
|
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
debug!(peer = %peer, "New connection");
|
debug!(peer = %peer, "New connection");
|
||||||
|
|
||||||
if let Err(e) = configure_client_socket(
|
if let Err(e) = configure_client_socket(
|
||||||
@@ -85,31 +253,34 @@ impl RunningClientHandler {
|
|||||||
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
|
||||||
let stats = self.stats.clone();
|
let stats = self.stats.clone();
|
||||||
|
|
||||||
let result = timeout(handshake_timeout, self.do_handshake()).await;
|
// Phase 1: handshake (with timeout)
|
||||||
|
let outcome = match timeout(handshake_timeout, self.do_handshake()).await {
|
||||||
match result {
|
Ok(Ok(outcome)) => outcome,
|
||||||
Ok(Ok(())) => {
|
|
||||||
debug!(peer = %peer, "Connection handled successfully");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(peer = %peer, error = %e, "Handshake failed");
|
debug!(peer = %peer, error = %e, "Handshake failed");
|
||||||
Err(e)
|
return Err(e);
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
stats.increment_handshake_timeouts();
|
stats.increment_handshake_timeouts();
|
||||||
debug!(peer = %peer, "Handshake timeout");
|
debug!(peer = %peer, "Handshake timeout");
|
||||||
Err(ProxyError::TgHandshakeTimeout)
|
return Err(ProxyError::TgHandshakeTimeout);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
|
||||||
|
match outcome {
|
||||||
|
HandshakeOutcome::NeedsRelay(fut) => fut.await,
|
||||||
|
HandshakeOutcome::Handled => Ok(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn do_handshake(mut self) -> Result<()> {
|
async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
|
||||||
let mut first_bytes = [0u8; 5];
|
let mut first_bytes = [0u8; 5];
|
||||||
self.stream.read_exact(&mut first_bytes).await?;
|
self.stream.read_exact(&mut first_bytes).await?;
|
||||||
|
|
||||||
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
|
||||||
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
|
||||||
|
|
||||||
@@ -120,8 +291,9 @@ impl RunningClientHandler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
@@ -132,7 +304,7 @@ impl RunningClientHandler {
|
|||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
let (reader, writer) = self.stream.into_split();
|
let (reader, writer) = self.stream.into_split();
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut handshake = vec![0u8; 5 + tls_len];
|
let mut handshake = vec![0u8; 5 + tls_len];
|
||||||
@@ -162,7 +334,7 @@ impl RunningClientHandler {
|
|||||||
HandshakeResult::BadClient { reader, writer } => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
return Ok(());
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
@@ -191,11 +363,12 @@ impl RunningClientHandler {
|
|||||||
} => {
|
} => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
return Ok(());
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||||
Self::handle_authenticated_static(
|
Self::handle_authenticated_static(
|
||||||
crypto_reader,
|
crypto_reader,
|
||||||
crypto_writer,
|
crypto_writer,
|
||||||
@@ -209,19 +382,20 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
peer,
|
peer,
|
||||||
self.ip_tracker,
|
self.ip_tracker,
|
||||||
)
|
),
|
||||||
.await
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
|
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> {
|
||||||
let peer = self.peer;
|
let peer = self.peer;
|
||||||
|
let ip_tracker = self.ip_tracker.clone();
|
||||||
|
|
||||||
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
if !self.config.general.modes.classic && !self.config.general.modes.secure {
|
||||||
debug!(peer = %peer, "Non-TLS modes disabled");
|
debug!(peer = %peer, "Non-TLS modes disabled");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
let (reader, writer) = self.stream.into_split();
|
let (reader, writer) = self.stream.into_split();
|
||||||
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
|
||||||
return Ok(());
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut handshake = [0u8; HANDSHAKE_LEN];
|
let mut handshake = [0u8; HANDSHAKE_LEN];
|
||||||
@@ -251,11 +425,12 @@ impl RunningClientHandler {
|
|||||||
HandshakeResult::BadClient { reader, writer } => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
handle_bad_client(reader, writer, &handshake, &config).await;
|
handle_bad_client(reader, writer, &handshake, &config).await;
|
||||||
return Ok(());
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||||
Self::handle_authenticated_static(
|
Self::handle_authenticated_static(
|
||||||
crypto_reader,
|
crypto_reader,
|
||||||
crypto_writer,
|
crypto_writer,
|
||||||
@@ -269,8 +444,8 @@ impl RunningClientHandler {
|
|||||||
local_addr,
|
local_addr,
|
||||||
peer,
|
peer,
|
||||||
self.ip_tracker,
|
self.ip_tracker,
|
||||||
)
|
),
|
||||||
.await
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Main dispatch after successful handshake.
|
/// Main dispatch after successful handshake.
|
||||||
|
|||||||
Reference in New Issue
Block a user