This commit is contained in:
Alexey
2026-01-20 01:20:02 +03:00
parent 038f0cd5d1
commit 2ce8fbb2cc
11 changed files with 634 additions and 474 deletions

View File

@@ -14,7 +14,7 @@ use crate::protocol::constants::*;
use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr;
use super::handshake::{
@@ -35,6 +35,7 @@ pub struct RunningClientHandler {
stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
}
impl ClientHandler {
@@ -45,11 +46,9 @@ impl ClientHandler {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
) -> RunningClientHandler {
// CHANGED: Removed local creation of ReplayChecker.
// It is now passed from main.rs to ensure global replay protection.
RunningClientHandler {
stream,
peer,
@@ -57,6 +56,7 @@ impl ClientHandler {
stats,
replay_checker,
upstream_manager,
buffer_pool,
}
}
}
@@ -72,14 +72,14 @@ impl RunningClientHandler {
// Configure socket
if let Err(e) = configure_client_socket(
&self.stream,
self.config.client_keepalive,
self.config.client_ack_timeout,
self.config.timeouts.client_keepalive,
self.config.timeouts.client_ack,
) {
debug!(peer = %peer, error = %e, "Failed to configure client socket");
}
// Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
// Clone stats for error handling block
let stats = self.stats.clone();
@@ -139,7 +139,9 @@ impl RunningClientHandler {
if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await;
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(());
}
@@ -152,6 +154,7 @@ impl RunningClientHandler {
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing
let (read_half, write_half) = self.stream.into_split();
@@ -166,8 +169,9 @@ impl RunningClientHandler {
&replay_checker,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
@@ -190,27 +194,23 @@ impl RunningClientHandler {
true,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
// Valid TLS but invalid MTProto - drop
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
};
// Handle authenticated client
// We can't use self.handle_authenticated_inner because self is partially moved
// So we call it as an associated function or method on a new struct,
// or just inline the logic / use a static method.
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
// we should pass them explicitly.
Self::handle_authenticated_static(
crypto_reader,
crypto_writer,
success,
self.upstream_manager,
self.stats,
self.config
self.config,
buffer_pool
).await
}
@@ -222,10 +222,12 @@ impl RunningClientHandler {
let peer = self.peer;
// Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure {
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await;
// FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(());
}
@@ -238,6 +240,7 @@ impl RunningClientHandler {
let config = self.config.clone();
let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream
let (read_half, write_half) = self.stream.into_split();
@@ -253,8 +256,9 @@ impl RunningClientHandler {
false,
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => {
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(());
}
HandshakeResult::Error(e) => return Err(e),
@@ -266,11 +270,12 @@ impl RunningClientHandler {
success,
self.upstream_manager,
self.stats,
self.config
self.config,
buffer_pool
).await
}
/// Static version of handle_authenticated_inner to avoid ownership issues
/// Static version of handle_authenticated_inner
async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>,
@@ -278,6 +283,7 @@ impl RunningClientHandler {
upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>,
config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
@@ -300,7 +306,7 @@ impl RunningClientHandler {
dc = success.dc_idx,
dc_addr = %dc_addr,
proto = ?success.proto_tag,
fast_mode = config.fast_mode,
fast_mode = config.general.fast_mode,
"Connecting to Telegram"
);
@@ -322,7 +328,7 @@ impl RunningClientHandler {
stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
// Relay traffic
// Relay traffic using buffer pool
let relay_result = relay_bidirectional(
client_reader,
client_writer,
@@ -330,6 +336,7 @@ impl RunningClientHandler {
tg_writer,
user,
Arc::clone(&stats),
buffer_pool,
).await;
// Update stats
@@ -346,14 +353,14 @@ impl RunningClientHandler {
/// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration
if let Some(expiration) = config.user_expirations.get(user) {
if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() });
}
}
// Check connection limit
if let Some(limit) = config.user_max_tcp_conns.get(user) {
if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
@@ -361,7 +368,7 @@ impl RunningClientHandler {
}
// Check data quota
if let Some(quota) = config.user_data_quota.get(user) {
if let Some(quota) = config.access.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
@@ -375,7 +382,7 @@ impl RunningClientHandler {
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if config.prefer_ipv6 {
let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6
} else {
&*TG_DATACENTERS_V4
@@ -399,7 +406,7 @@ impl RunningClientHandler {
success.proto_tag,
&success.dec_key, // Client's dec key
success.dec_iv,
config.fast_mode,
config.general.fast_mode,
);
// Encrypt nonce

View File

@@ -42,7 +42,7 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr,
config: &ProxyConfig,
replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)>
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
@@ -52,7 +52,7 @@ where
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient;
return HandshakeResult::BadClient { reader, writer };
}
// Extract digest for replay check
@@ -62,11 +62,11 @@ where
// Check for replay
if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected");
return HandshakeResult::BadClient;
return HandshakeResult::BadClient { reader, writer };
}
// Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
})
@@ -78,19 +78,19 @@ where
let validation = match tls::validate_tls_handshake(
handshake,
&secrets,
config.ignore_time_skew,
config.access.ignore_time_skew,
) {
Some(v) => v,
None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user");
return HandshakeResult::BadClient;
return HandshakeResult::BadClient { reader, writer };
}
};
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => return HandshakeResult::BadClient,
None => return HandshakeResult::BadClient { reader, writer },
};
// Build and send response
@@ -98,7 +98,7 @@ where
secret,
&validation.digest,
&validation.session_id,
config.fake_cert_len,
config.censorship.fake_cert_len,
);
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -136,7 +136,7 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig,
replay_checker: &ReplayChecker,
is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)>
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
@@ -155,14 +155,14 @@ where
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient;
return HandshakeResult::BadClient { reader, writer };
}
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret
for (user, secret_hex) in &config.users {
for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) {
Ok(s) => s,
Err(_) => continue,
@@ -208,9 +208,9 @@ where
// Check if mode is enabled
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure }
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic,
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
};
if !mode_ok {
@@ -270,7 +270,7 @@ where
}
debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient
HandshakeResult::BadClient { reader, writer }
}
/// Generate nonce for Telegram connection

View File

@@ -1,35 +1,73 @@
//! Masking - forward unrecognized traffic to mask host
use std::time::Duration;
use std::str;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use tracing::debug;
use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
const MASK_BUFFER_SIZE: usize = 8192;
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4 {
if data.starts_with(b"GET ") || data.starts_with(b"POST") ||
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") {
return "HTTP";
}
}
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
return "TLS-scanner";
}
// Check for SSH
if data.starts_with(b"SSH-") {
return "SSH";
}
// Port scanner (very short data)
if data.len() < 10 {
return "port-scanner";
}
"unknown"
}
/// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client(
client: TcpStream,
pub async fn handle_bad_client<R, W>(
mut reader: R,
mut writer: W,
initial_data: &[u8],
config: &ProxyConfig,
) {
if !config.mask {
)
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data(client).await;
consume_client_data(reader).await;
return;
}
let mask_host = config.mask_host.as_deref()
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port;
let client_type = detect_client_type(initial_data);
let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
data_len = initial_data.len(),
"Forwarding bad client to mask host"
);
@@ -40,33 +78,32 @@ pub async fn handle_bad_client(
TcpStream::connect(&mask_addr)
).await;
let mut mask_stream = match connect_result {
let mask_stream = match connect_result {
Ok(Ok(s)) => s,
Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await;
consume_client_data(reader).await;
return;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data(client).await;
consume_client_data(reader).await;
return;
}
};
let (mut mask_read, mut mask_write) = mask_stream.into_split();
// Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() {
if mask_write.write_all(initial_data).await.is_err() {
return;
}
// Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop {
match client_read.read(&mut buf).await {
match reader.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await;
break;
@@ -85,11 +122,11 @@ pub async fn handle_bad_client(
loop {
match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => {
let _ = client_write.shutdown().await;
let _ = writer.shutdown().await;
break;
}
Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() {
if writer.write_all(&buf[..n]).await.is_err() {
break;
}
}
@@ -105,9 +142,9 @@ pub async fn handle_bad_client(
}
/// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) {
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await {
while let Ok(n) = reader.read(&mut buf).await {
if n == 0 {
break;
}

View File

@@ -7,14 +7,10 @@ use tokio::time::Instant;
use tracing::{debug, trace, warn, info};
use crate::error::Result;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::atomic::{AtomicU64, Ordering};
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
const BUFFER_SIZE: usize = 16384;
// Activity timeout for iOS compatibility (30 minutes)
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
/// Relay data bidirectionally between client and server
@@ -25,6 +21,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut server_writer: SW,
user: &str,
stats: Arc<Stats>,
buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
@@ -35,7 +32,6 @@ where
let user_c2s = user.to_string();
let user_s2c = user.to_string();
// Используем Arc::clone вместо stats.clone()
let stats_c2s = Arc::clone(&stats);
let stats_s2c = Arc::clone(&stats);
@@ -44,26 +40,29 @@ where
let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes);
// Activity timeout for iOS compatibility
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
// Client -> Server task with activity timeout
let pool_c2s = buffer_pool.clone();
let pool_s2c = buffer_pool.clone();
// Client -> Server task
let c2s = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE];
// Get buffer from pool
let mut buf = pool_c2s.get();
let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop {
// Read with timeout to prevent infinite hang on iOS
// Read with timeout
let read_result = tokio::time::timeout(
activity_timeout,
client_reader.read(&mut buf)
).await;
match read_result {
// Timeout - no activity for too long
Err(_) => {
warn!(
user = %user_c2s,
@@ -76,7 +75,6 @@ where
break;
}
// Read successful
Ok(Ok(0)) => {
debug!(
user = %user_c2s,
@@ -101,21 +99,26 @@ where
user = %user_c2s,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data"
);
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
// Log activity every 10 seconds with correct rate
let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
// Changed to DEBUG to reduce log spam
debug!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = server_writer.write_all(&buf[..n]).await {
@@ -136,23 +139,23 @@ where
}
});
// Server -> Client task with activity timeout
// Server -> Client task
let s2c = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE];
// Get buffer from pool
let mut buf = pool_s2c.get();
let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop {
// Read with timeout to prevent infinite hang on iOS
let read_result = tokio::time::timeout(
activity_timeout,
server_reader.read(&mut buf)
).await;
match read_result {
// Timeout - no activity for too long
Err(_) => {
warn!(
user = %user_s2c,
@@ -165,7 +168,6 @@ where
break;
}
// Read successful
Ok(Ok(0)) => {
debug!(
user = %user_s2c,
@@ -190,21 +192,25 @@ where
user = %user_s2c,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"S->C data"
);
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
// Changed to DEBUG to reduce log spam
debug!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = client_writer.write_all(&buf[..n]).await {