Drafting Upstreams and SOCKS
This commit is contained in:
Alexey
2026-01-07 17:22:10 +03:00
parent 7746a1177c
commit 4f007f3128
10 changed files with 839 additions and 244 deletions

View File

@@ -13,7 +13,7 @@ use crate::error::{ProxyError, Result, HandshakeResult};
use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{ConnectionPool, configure_client_socket};
use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::crypto::AesCtr;
@@ -24,39 +24,55 @@ use super::handshake::{
use super::relay::relay_bidirectional;
use super::masking::handle_bad_client;
/// Client connection handler
pub struct ClientHandler {
/// Client connection handler (builder struct)
pub struct ClientHandler;
/// Running client handler with stream and context
pub struct RunningClientHandler {
stream: TcpStream,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>,
upstream_manager: Arc<UpstreamManager>,
}
impl ClientHandler {
/// Create new client handler
/// Create new client handler instance
pub fn new(
stream: TcpStream,
peer: SocketAddr,
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>,
pool: Arc<ConnectionPool>,
) -> Self {
Self {
upstream_manager: Arc<UpstreamManager>,
) -> RunningClientHandler {
// Note: ReplayChecker should be shared globally for proper replay protection
// Creating it per-connection disables replay protection across connections
// TODO: Pass Arc<ReplayChecker> from main.rs
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
RunningClientHandler {
stream,
peer,
config,
stats,
replay_checker,
pool,
upstream_manager,
}
}
/// Handle a client connection
pub async fn handle(&self, stream: TcpStream, peer: SocketAddr) {
}
impl RunningClientHandler {
/// Run the client handler
pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all();
let peer = self.peer;
debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket(
&stream,
&self.stream,
self.config.client_keepalive,
self.config.client_ack_timeout,
) {
@@ -66,49 +82,56 @@ impl ClientHandler {
// Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
// Clone stats for error handling block
let stats = self.stats.clone();
let result = timeout(
handshake_timeout,
self.do_handshake(stream, peer)
self.do_handshake()
).await;
match result {
Ok(Ok(())) => {
debug!(peer = %peer, "Connection handled successfully");
Ok(())
}
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
Err(e)
}
Err(_) => {
self.stats.increment_handshake_timeouts();
stats.increment_handshake_timeouts();
debug!(peer = %peer, "Handshake timeout");
Err(ProxyError::TgHandshakeTimeout)
}
}
}
/// Perform handshake and relay
async fn do_handshake(&self, mut stream: TcpStream, peer: SocketAddr) -> Result<()> {
async fn do_handshake(mut self) -> Result<()> {
// Read first bytes to determine handshake type
let mut first_bytes = [0u8; 5];
stream.read_exact(&mut first_bytes).await?;
self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer;
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected");
if is_tls {
self.handle_tls_client(stream, peer, first_bytes).await
self.handle_tls_client(first_bytes).await
} else {
self.handle_direct_client(stream, peer, first_bytes).await
self.handle_direct_client(first_bytes).await
}
}
/// Handle TLS-wrapped client
async fn handle_tls_client(
&self,
mut stream: TcpStream,
peer: SocketAddr,
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer;
// Read TLS handshake length
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
@@ -117,17 +140,22 @@ impl ClientHandler {
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await;
handle_bad_client(self.stream, &first_bytes, &self.config).await;
return Ok(());
}
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields before consuming self.stream
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
// Split stream for reading/writing
let (read_half, write_half) = stream.into_split();
let (read_half, write_half) = self.stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
@@ -135,12 +163,12 @@ impl ClientHandler {
read_half,
write_half,
peer,
&self.config,
&self.replay_checker,
&config,
&replay_checker,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
@@ -158,44 +186,62 @@ impl ClientHandler {
tls_reader,
tls_writer,
peer,
&self.config,
&self.replay_checker,
&config,
&replay_checker,
true,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
// Handle authenticated client
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
// We can't use self.handle_authenticated_inner because self is partially moved
// So we call it as an associated function or method on a new struct,
// or just inline the logic / use a static method.
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
// we should pass them explicitly.
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config
).await
}
/// Handle direct (non-TLS) client
async fn handle_direct_client(
&self,
mut stream: TcpStream,
peer: SocketAddr,
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer;
// Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
handle_bad_client(stream, &first_bytes, &self.config).await;
handle_bad_client(self.stream, &first_bytes, &self.config).await;
return Ok(());
}
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes);
stream.read_exact(&mut handshake[5..]).await?;
self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
// Split stream
let (read_half, write_half) = stream.into_split();
let (read_half, write_half) = self.stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
@@ -203,27 +249,36 @@ impl ClientHandler {
read_half,
write_half,
peer,
&self.config,
&self.replay_checker,
&config,
&replay_checker,
false,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
self.stats.increment_connects_bad();
stats.increment_connects_bad();
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
self.handle_authenticated_inner(crypto_reader, crypto_writer, success).await
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config
).await
}
/// Handle authenticated client - connect to Telegram and relay
async fn handle_authenticated_inner<R, W>(
&self,
/// Static version of handle_authenticated_inner to avoid ownership issues
async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
success: HandshakeSuccess,
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
@@ -232,13 +287,13 @@ impl ClientHandler {
let user = &success.user;
// Check user limits
if let Err(e) = self.check_user_limits(user) {
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
warn!(user = %user, error = %e, "User limit exceeded");
return Err(e);
}
// Get datacenter address
let dc_addr = self.get_dc_addr(success.dc_idx)?;
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
info!(
user = %user,
@@ -246,39 +301,40 @@ impl ClientHandler {
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
fast_mode = self.config.fast_mode,
fast_mode = config.fast_mode,
"Connecting to Telegram"
);
// Connect to Telegram
let tg_stream = self.pool.get(dc_addr).await?;
// Connect to Telegram via UpstreamManager
let tg_stream = upstream_manager.connect(dc_addr).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake");
// Perform Telegram handshake and get crypto streams
let (tg_reader, tg_writer) = self.do_tg_handshake(
let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
tg_stream,
&success,
&config,
).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay");
// Update stats
self.stats.increment_user_connects(user);
self.stats.increment_user_curr_connects(user);
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
// Relay traffic - передаём Arc::clone(&self.stats)
// Relay traffic
let relay_result = relay_bidirectional(
client_reader,
client_writer,
tg_reader,
tg_writer,
user,
Arc::clone(&self.stats),
Arc::clone(&stats),
).await;
// Update stats
self.stats.decrement_user_curr_connects(user);
stats.decrement_user_curr_connects(user);
match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"),
@@ -288,26 +344,26 @@ impl ClientHandler {
relay_result
}
/// Check user limits (expiration, connection count, data quota)
fn check_user_limits(&self, user: &str) -> Result<()> {
/// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration
if let Some(expiration) = self.config.user_expirations.get(user) {
if let Some(expiration) = config.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() });
}
}
// Check connection limit
if let Some(limit) = self.config.user_max_tcp_conns.get(user) {
let current = self.stats.get_user_curr_connects(user);
if let Some(limit) = config.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
}
}
// Check data quota
if let Some(quota) = self.config.user_data_quota.get(user) {
let used = self.stats.get_user_total_octets(user);
if let Some(quota) = config.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
}
@@ -316,11 +372,11 @@ impl ClientHandler {
Ok(())
}
/// Get datacenter address by index
fn get_dc_addr(&self, dc_idx: i16) -> Result<SocketAddr> {
/// Get datacenter address by index (static version)
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if self.config.prefer_ipv6 {
let datacenters = if config.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
@@ -333,19 +389,18 @@ impl ClientHandler {
))
}
/// Perform handshake with Telegram server
/// Returns crypto reader and writer for TG connection
async fn do_tg_handshake(
&self,
/// Perform handshake with Telegram server (static version)
async fn do_tg_handshake_static(
mut stream: TcpStream,
success: &HandshakeSuccess,
config: &ProxyConfig,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag,
&success.dec_key, // Client's dec key
success.dec_iv,
self.config.fast_mode,
config.fast_mode,
);
// Encrypt nonce