This commit is contained in:
Alexey
2026-01-20 01:20:02 +03:00
parent 038f0cd5d1
commit 2ce8fbb2cc
11 changed files with 634 additions and 474 deletions

View File

@@ -118,44 +118,100 @@ then Ctrl+X -> Y -> Enter to save
## Configuration ## Configuration
### Minimal Configuration for First Start ### Minimal Configuration for First Start
```toml ```toml
port = 443 # Listening port # === General Settings ===
show_link = ["tele", "hello"] # Specify users, for whom will be displayed the links [general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
tls_domain = "petrovich.ru" # Domain for ee-secret and masking [general.modes]
mask = true # Enable masking of bad traffic classic = false
mask_host = "petrovich.ru" # Optional override for mask destination secure = false
mask_port = 443 # Port for masking tls = true
prefer_ipv6 = false # Try IPv6 DCs first if true # === Server Binding ===
fast_mode = true # Use "fast" obfuscation variant [server]
port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
client_keepalive = 600 # Seconds # Listen on multiple interfaces/IPs (overrides listen_addr_*)
client_ack_timeout = 300 # Seconds [[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[modes] [[server.listeners]]
classic = true # Plain obfuscated mode ip = "::"
secure = true # dd-prefix mode
tls = true # Fake TLS (ee-prefix)
[users] # === Timeouts (in seconds) ===
hello = "00000000000000000000000000000000" # Replace the secret with one generated before [timeouts]
tele = "00000000000000000000000000000000" # Replace the secret with one generated before client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# The username "hello" is used as an example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
``` ```
### Advanced ### Advanced
#### Adtag #### Adtag
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot) and add this parameter to the `[general]` section
```toml ```toml
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
``` ```
#### Listening and Announce IPs #### Listening and Announce IPs
To specify listening address and/or address in links, add to the end of config.toml: To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml:
```toml ```toml
[[listeners]] [[server.listeners]]
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
announce_ip = "1.2.3.4" # IP in links; comment with # if not used announce_ip = "1.2.3.4" # IP in links; comment with # if not used
``` ```
#### Upstream Manager #### Upstream Manager
To specify upstream, add to the end of config.toml: To specify upstream, add to section `[[upstreams]]` of config.toml:
##### Bind on IP ##### Bind on IP
```toml ```toml
[[upstreams]] [[upstreams]]

View File

@@ -1,13 +1,78 @@
port = 443 # === General Settings ===
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[users] [general.modes]
user1 = "00000000000000000000000000000000" classic = false
secure = false
[modes]
classic = true
secure = true
tls = true tls = true
tls_domain = "www.github.com" # === Server Binding ===
fast_mode = true [server]
prefer_ipv6 = false port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# The username "hello" is used as an example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]

View File

@@ -7,6 +7,29 @@ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
// ============= Helper Defaults =============
fn default_true() -> bool { true }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_weight() -> u16 { 1 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes { pub struct ProxyModes {
#[serde(default)] #[serde(default)]
@@ -17,26 +40,185 @@ pub struct ProxyModes {
pub tls: bool, pub tls: bool,
} }
fn default_true() -> bool { true }
fn default_weight() -> u16 { 1 }
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { classic: true, secure: true, tls: true } Self { classic: true, secure: true, tls: true }
} }
} }
/// General proxy behaviour (`[general]` table in config.toml).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
    /// Enabled handshake/obfuscation modes (`[general.modes]`).
    #[serde(default)]
    pub modes: ProxyModes,
    /// Try IPv6 Telegram DCs first when true.
    #[serde(default)]
    pub prefer_ipv6: bool,
    /// Use the "fast" obfuscation variant (default: true).
    #[serde(default = "default_true")]
    pub fast_mode: bool,
    /// Route traffic through Telegram middle proxies.
    /// NOTE(review): presumably required for ad_tag to take effect — confirm.
    #[serde(default)]
    pub use_middle_proxy: bool,
    /// Advertising tag obtained from @mtproxybot; None disables channel ads.
    #[serde(default)]
    pub ad_tag: Option<String>,
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
ad_tag: None,
}
}
}
/// Network binding and metrics settings (`[server]` table in config.toml).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Listening port for client connections (default: 443).
    #[serde(default = "default_port")]
    pub port: u16,
    /// IPv4 bind address (default: "0.0.0.0").
    #[serde(default = "default_listen_addr")]
    pub listen_addr_ipv4: String,
    /// Optional IPv6 bind address (e.g. "::"); None skips IPv6 binding.
    #[serde(default)]
    pub listen_addr_ipv6: Option<String>,
    /// Optional Unix socket path.
    /// NOTE(review): binding logic for this field is not visible here — confirm it is used.
    #[serde(default)]
    pub listen_unix_sock: Option<String>,
    /// Port for the metrics endpoint; None disables metrics.
    #[serde(default)]
    pub metrics_port: Option<u16>,
    /// IPs allowed to query the metrics endpoint (default: loopback only).
    #[serde(default = "default_metrics_whitelist")]
    pub metrics_whitelist: Vec<IpAddr>,
    /// Explicit listeners; when empty, populated at load time from
    /// listen_addr_ipv4/listen_addr_ipv6 (see the migration in ProxyConfig::load).
    #[serde(default)]
    pub listeners: Vec<ListenerConfig>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
/// Connection timeout settings, all in seconds (`[timeouts]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
    /// Max time for a client to complete its handshake (default: 15).
    #[serde(default = "default_handshake_timeout")]
    pub client_handshake: u64,
    /// Max time to establish the upstream Telegram connection (default: 10).
    #[serde(default = "default_connect_timeout")]
    pub tg_connect: u64,
    /// TCP keepalive interval for client sockets (default: 60; kept low
    /// because mobile NATs often drop idle connections after 60-120s).
    #[serde(default = "default_keepalive")]
    pub client_keepalive: u64,
    /// Client ACK timeout applied to the client socket (default: 300).
    #[serde(default = "default_ack_timeout")]
    pub client_ack: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
}
}
}
/// FakeTLS and bad-traffic masking settings (`[censorship]` table).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
    /// Domain presented in the FakeTLS (ee-secret) handshake and used as
    /// the default masking destination (default: "www.google.com").
    #[serde(default = "default_tls_domain")]
    pub tls_domain: String,
    /// Enable masking of unrecognised ("bad") traffic (default: true).
    #[serde(default = "default_true")]
    pub mask: bool,
    /// Optional override for the masking destination host; falls back to
    /// tls_domain at load time when unset.
    #[serde(default)]
    pub mask_host: Option<String>,
    /// Port to connect to when masking (default: 443).
    #[serde(default = "default_mask_port")]
    pub mask_port: u16,
    /// Fake certificate length. NOTE: ProxyConfig::load unconditionally
    /// replaces this with a random value in 1024..4096, so the configured
    /// value is currently ignored.
    #[serde(default = "default_fake_cert_len")]
    pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
fake_cert_len: default_fake_cert_len(),
}
}
}
/// Access control and per-user limits (`[access]` table in config.toml).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
    /// Map of username -> 32-hex-character secret (`[access.users]`);
    /// secrets are validated for length/hex in ProxyConfig::load.
    #[serde(default)]
    pub users: HashMap<String, String>,
    /// Optional per-user cap on simultaneous TCP connections.
    #[serde(default)]
    pub user_max_tcp_conns: HashMap<String, usize>,
    /// Optional per-user account expiration timestamps (UTC).
    /// NOTE(review): enforcement is not visible in this view — confirm.
    #[serde(default)]
    pub user_expirations: HashMap<String, DateTime<Utc>>,
    /// Optional per-user data quota in bytes (e.g. 1073741824 = 1 GB).
    #[serde(default)]
    pub user_data_quota: HashMap<String, u64>,
    /// Capacity passed to the global ReplayChecker at startup (default: 65536).
    #[serde(default = "default_replay_check_len")]
    pub replay_check_len: usize,
    /// When true, skip client clock-skew validation.
    /// NOTE(review): enforcement is not visible in this view — confirm.
    #[serde(default)]
    pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")] #[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType { pub enum UpstreamType {
Direct { Direct {
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface interface: Option<String>,
}, },
Socks4 { Socks4 {
address: String, // IP:Port of SOCKS server address: String,
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface for connection to SOCKS interface: Option<String>,
#[serde(default)] #[serde(default)]
user_id: Option<String>, user_id: Option<String>,
}, },
@@ -65,160 +247,35 @@ pub struct UpstreamConfig {
pub struct ListenerConfig { pub struct ListenerConfig {
pub ip: IpAddr, pub ip: IpAddr,
#[serde(default)] #[serde(default)]
pub announce_ip: Option<IpAddr>, // IP to show in tg:// links pub announce_ip: Option<IpAddr>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] // ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig { pub struct ProxyConfig {
#[serde(default = "default_port")] #[serde(default)]
pub port: u16, pub general: GeneralConfig,
#[serde(default)] #[serde(default)]
pub users: HashMap<String, String>, pub server: ServerConfig,
#[serde(default)] #[serde(default)]
pub ad_tag: Option<String>, pub timeouts: TimeoutsConfig,
#[serde(default)] #[serde(default)]
pub modes: ProxyModes, pub censorship: AntiCensorshipConfig,
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)] #[serde(default)]
pub mask_host: Option<String>, pub access: AccessConfig,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
#[serde(default = "default_handshake_timeout")]
pub client_handshake_timeout: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect_timeout: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack_timeout: u64,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
// New fields
#[serde(default)] #[serde(default)]
pub upstreams: Vec<UpstreamConfig>, pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
#[serde(default)] #[serde(default)]
pub show_link: Vec<String>, pub show_link: Vec<String>,
} }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
// CHANGED: Increased handshake timeout for bad mobile networks
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
// CHANGED: Reduced keepalive from 600s to 60s.
// Mobile NATs often drop idle connections after 60-120s.
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
impl Default for ProxyConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
port: default_port(),
users,
ad_tag: None,
modes: ProxyModes::default(),
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
client_handshake_timeout: default_handshake_timeout(),
tg_connect_timeout: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack_timeout: default_ack_timeout(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
fake_cert_len: default_fake_cert_len(),
upstreams: Vec::new(),
listeners: Vec::new(),
show_link: Vec::new(),
}
}
}
impl ProxyConfig { impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path) let content = std::fs::read_to_string(path)
@@ -228,7 +285,7 @@ impl ProxyConfig {
.map_err(|e| ProxyError::Config(e.to_string()))?; .map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.users { for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret { return Err(ProxyError::InvalidSecret {
user: user.clone(), user: user.clone(),
@@ -237,26 +294,37 @@ impl ProxyConfig {
} }
} }
// Default mask_host // Validate tls_domain
if config.mask_host.is_none() { if config.censorship.tls_domain.is_empty() {
config.mask_host = Some(config.tls_domain.clone()); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Warn if using default tls_domain
if config.censorship.tls_domain == "www.google.com" {
tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml");
}
// Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() {
tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain);
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
// Random fake_cert_len // Random fake_cert_len
use rand::Rng; use rand::Rng;
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
// Migration: Populate listeners if empty // Migration: Populate listeners if empty
if config.listeners.is_empty() { if config.server.listeners.is_empty() {
if let Ok(ipv4) = config.listen_addr_ipv4.parse::<IpAddr>() { if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv4, ip: ipv4,
announce_ip: None, announce_ip: None,
}); });
} }
if let Some(ipv6_str) = &config.listen_addr_ipv6 { if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() { if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv6, ip: ipv6,
announce_ip: None, announce_ip: None,
}); });
@@ -277,14 +345,21 @@ impl ProxyConfig {
} }
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.users.is_empty() { if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string())); return Err(ProxyError::Config("No users configured".to_string()));
} }
if !self.modes.classic && !self.modes.secure && !self.modes.tls { if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
// Validate tls_domain format (basic check)
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
));
}
Ok(()) Ok(())
} }
} }

View File

@@ -297,16 +297,16 @@ pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Result with optional bad client handling /// Result with optional bad client handling
#[derive(Debug)] #[derive(Debug)]
pub enum HandshakeResult<T> { pub enum HandshakeResult<T, R, W> {
/// Handshake succeeded /// Handshake succeeded
Success(T), Success(T),
/// Client failed validation, needs masking /// Client failed validation, needs masking. Returns ownership of streams.
BadClient, BadClient { reader: R, writer: W },
/// Error occurred /// Error occurred
Error(ProxyError), Error(ProxyError),
} }
impl<T> HandshakeResult<T> { impl<T, R, W> HandshakeResult<T, R, W> {
/// Check if successful /// Check if successful
pub fn is_success(&self) -> bool { pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_)) matches!(self, HandshakeResult::Success(_))
@@ -314,49 +314,32 @@ impl<T> HandshakeResult<T> {
/// Check if bad client /// Check if bad client
pub fn is_bad_client(&self) -> bool { pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient) matches!(self, HandshakeResult::BadClient { .. })
}
/// Convert to Result, treating BadClient as error
pub fn into_result(self) -> Result<T> {
match self {
HandshakeResult::Success(v) => Ok(v),
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
HandshakeResult::Error(e) => Err(e),
}
} }
/// Map the success value /// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient => HandshakeResult::BadClient, HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
/// Convert success to Option
pub fn ok(self) -> Option<T> {
match self {
HandshakeResult::Success(v) => Some(v),
_ => None,
}
}
} }
impl<T> From<ProxyError> for HandshakeResult<T> { impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
fn from(err: ProxyError) -> Self { fn from(err: ProxyError) -> Self {
HandshakeResult::Error(err) HandshakeResult::Error(err)
} }
} }
impl<T> From<std::io::Error> for HandshakeResult<T> { impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
fn from(err: std::io::Error) -> Self { fn from(err: std::io::Error) -> Self {
HandshakeResult::Error(ProxyError::Io(err)) HandshakeResult::Error(ProxyError::Io(err))
} }
} }
impl<T> From<StreamError> for HandshakeResult<T> { impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
fn from(err: StreamError) -> Self { fn from(err: StreamError) -> Self {
HandshakeResult::Error(ProxyError::Stream(err)) HandshakeResult::Error(ProxyError::Stream(err))
} }

View File

@@ -23,6 +23,7 @@ use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip; use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -52,15 +53,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config.validate()?; config.validate()?;
// Log loaded configuration for debugging
info!("=== Configuration Loaded ===");
info!("TLS Domain: {}", config.censorship.tls_domain);
info!("Mask enabled: {}", config.censorship.mask);
info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain));
info!("Mask port: {}", config.censorship.mask_port);
info!("Modes: classic={}, secure={}, tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls
);
info!("============================");
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
// CHANGED: Initialize global ReplayChecker here instead of per-connection // Initialize global ReplayChecker
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); // Using sharded implementation for better concurrency
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
// Initialize Upstream Manager // Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
// Initialize Buffer Pool
// 16KB buffers, max 4096 buffers (~64MB total cached)
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start Health Checks // Start Health Checks
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -73,8 +92,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Start Listeners // Start Listeners
let mut listeners = Vec::new(); let mut listeners = Vec::new();
for listener_conf in &config.listeners { for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.port); let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions { let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(), ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default() ..Default::default()
@@ -86,13 +105,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
info!("Listening on {}", addr); info!("Listening on {}", addr);
// Determine public IP for tg:// links // Determine public IP for tg:// links
// 1. Use explicit announce_ip if set
// 2. If listening on 0.0.0.0 or ::, use detected public IP
// 3. Otherwise use the bind IP
let public_ip = if let Some(ip) = listener_conf.announce_ip { let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip ip
} else if listener_conf.ip.is_unspecified() { } else if listener_conf.ip.is_unspecified() {
// Try to use detected IP of the same family
if listener_conf.ip.is_ipv4() { if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip) detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else { } else {
@@ -106,26 +121,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !config.show_link.is_empty() { if !config.show_link.is_empty() {
info!("--- Proxy Links for {} ---", public_ip); info!("--- Proxy Links for {} ---", public_ip);
for user_name in &config.show_link { for user_name in &config.show_link {
if let Some(secret) = config.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name); info!("User: {}", user_name);
// Classic if config.general.modes.classic {
if config.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}", info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
// DD (Secure) if config.general.modes.secure {
if config.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
// EE-TLS (FakeTLS) if config.general.modes.tls {
if config.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain);
let domain_hex = hex::encode(&config.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.port, secret, domain_hex); public_ip, config.server.port, secret, domain_hex);
} }
} else { } else {
warn!("User '{}' specified in show_link not found in users list", user_name); warn!("User '{}' specified in show_link not found in users list", user_name);
@@ -153,6 +165,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -162,6 +175,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
@@ -170,10 +184,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config, config,
stats, stats,
upstream_manager, upstream_manager,
replay_checker // Pass global checker replay_checker,
buffer_pool
).run().await { ).run().await {
// Log only relevant errors // Log only relevant errors
// debug!("Connection error: {}", e);
} }
}); });
} }

View File

@@ -14,7 +14,7 @@ use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use super::handshake::{ use super::handshake::{
@@ -35,6 +35,7 @@ pub struct RunningClientHandler {
stats: Arc<Stats>, stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
} }
impl ClientHandler { impl ClientHandler {
@@ -45,11 +46,9 @@ impl ClientHandler {
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
) -> RunningClientHandler { ) -> RunningClientHandler {
// CHANGED: Removed local creation of ReplayChecker.
// It is now passed from main.rs to ensure global replay protection.
RunningClientHandler { RunningClientHandler {
stream, stream,
peer, peer,
@@ -57,6 +56,7 @@ impl ClientHandler {
stats, stats,
replay_checker, replay_checker,
upstream_manager, upstream_manager,
buffer_pool,
} }
} }
} }
@@ -72,14 +72,14 @@ impl RunningClientHandler {
// Configure socket // Configure socket
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&self.stream, &self.stream,
self.config.client_keepalive, self.config.timeouts.client_keepalive,
self.config.client_ack_timeout, self.config.timeouts.client_ack,
) { ) {
debug!(peer = %peer, error = %e, "Failed to configure client socket"); debug!(peer = %peer, error = %e, "Failed to configure client socket");
} }
// Perform handshake with timeout // Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
// Clone stats for error handling block // Clone stats for error handling block
let stats = self.stats.clone(); let stats = self.stats.clone();
@@ -139,7 +139,9 @@ impl RunningClientHandler {
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; // FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
@@ -152,6 +154,7 @@ impl RunningClientHandler {
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing // Split stream for reading/writing
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@@ -166,8 +169,9 @@ impl RunningClientHandler {
&replay_checker, &replay_checker,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -190,27 +194,23 @@ impl RunningClientHandler {
true, true,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
// Valid TLS but invalid MTProto - drop
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Handle authenticated client
// We can't use self.handle_authenticated_inner because self is partially moved
// So we call it as an associated function or method on a new struct,
// or just inline the logic / use a static method.
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
// we should pass them explicitly.
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader,
crypto_writer, crypto_writer,
success, success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config self.config,
buffer_pool
).await ).await
} }
@@ -222,10 +222,12 @@ impl RunningClientHandler {
let peer = self.peer; let peer = self.peer;
// Check if non-TLS modes are enabled // Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; // FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
@@ -238,6 +240,7 @@ impl RunningClientHandler {
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream // Split stream
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@@ -253,8 +256,9 @@ impl RunningClientHandler {
false, false,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -266,11 +270,12 @@ impl RunningClientHandler {
success, success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config self.config,
buffer_pool
).await ).await
} }
/// Static version of handle_authenticated_inner to avoid ownership issues /// Static version of handle_authenticated_inner
async fn handle_authenticated_static<R, W>( async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@@ -278,6 +283,7 @@ impl RunningClientHandler {
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>, stats: Arc<Stats>,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -300,7 +306,7 @@ impl RunningClientHandler {
dc = success.dc_idx, dc = success.dc_idx,
dc_addr = %dc_addr, dc_addr = %dc_addr,
proto = ?success.proto_tag, proto = ?success.proto_tag,
fast_mode = config.fast_mode, fast_mode = config.general.fast_mode,
"Connecting to Telegram" "Connecting to Telegram"
); );
@@ -322,7 +328,7 @@ impl RunningClientHandler {
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user); stats.increment_user_curr_connects(user);
// Relay traffic // Relay traffic using buffer pool
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
client_reader, client_reader,
client_writer, client_writer,
@@ -330,6 +336,7 @@ impl RunningClientHandler {
tg_writer, tg_writer,
user, user,
Arc::clone(&stats), Arc::clone(&stats),
buffer_pool,
).await; ).await;
// Update stats // Update stats
@@ -346,14 +353,14 @@ impl RunningClientHandler {
/// Check user limits (static version) /// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration // Check expiration
if let Some(expiration) = config.user_expirations.get(user) { if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired { user: user.to_string() });
} }
} }
// Check connection limit // Check connection limit
if let Some(limit) = config.user_max_tcp_conns.get(user) { if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user); let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 { if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
@@ -361,7 +368,7 @@ impl RunningClientHandler {
} }
// Check data quota // Check data quota
if let Some(quota) = config.user_data_quota.get(user) { if let Some(quota) = config.access.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user); let used = stats.get_user_total_octets(user);
if used >= *quota { if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
@@ -375,7 +382,7 @@ impl RunningClientHandler {
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize; let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if config.prefer_ipv6 { let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6 &*TG_DATACENTERS_V6
} else { } else {
&*TG_DATACENTERS_V4 &*TG_DATACENTERS_V4
@@ -399,7 +406,7 @@ impl RunningClientHandler {
success.proto_tag, success.proto_tag,
&success.dec_key, // Client's dec key &success.dec_key, // Client's dec key
success.dec_iv, success.dec_iv,
config.fast_mode, config.general.fast_mode,
); );
// Encrypt nonce // Encrypt nonce

View File

@@ -42,7 +42,7 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)> ) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
@@ -52,7 +52,7 @@ where
// Check minimum length // Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Extract digest for replay check // Extract digest for replay check
@@ -62,11 +62,11 @@ where
// Check for replay // Check for replay
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected"); warn!(peer = %peer, "TLS replay attack detected");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Build secrets list // Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.users.iter() let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| { .filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
}) })
@@ -78,19 +78,19 @@ where
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
&secrets, &secrets,
config.ignore_time_skew, config.access.ignore_time_skew,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user"); debug!(peer = %peer, "TLS handshake validation failed - no matching user");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
}; };
// Get secret for response // Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s, Some((_, s)) => s,
None => return HandshakeResult::BadClient, None => return HandshakeResult::BadClient { reader, writer },
}; };
// Build and send response // Build and send response
@@ -98,7 +98,7 @@ where
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
config.fake_cert_len, config.censorship.fake_cert_len,
); );
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
@@ -136,7 +136,7 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
is_tls: bool, is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)> ) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
@@ -155,14 +155,14 @@ where
// Check for replay // Check for replay
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Reversed for encryption direction // Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret // Try each user's secret
for (user, secret_hex) in &config.users { for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) { let secret = match hex::decode(secret_hex) {
Ok(s) => s, Ok(s) => s,
Err(_) => continue, Err(_) => continue,
@@ -208,9 +208,9 @@ where
// Check if mode is enabled // Check if mode is enabled
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure } if is_tls { config.general.modes.tls } else { config.general.modes.secure }
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };
if !mode_ok { if !mode_ok {
@@ -270,7 +270,7 @@ where
} }
debug!(peer = %peer, "MTProto handshake: no matching user found"); debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient HandshakeResult::BadClient { reader, writer }
} }
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection

View File

@@ -1,35 +1,73 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::time::Duration; use std::time::Duration;
use std::str;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::debug; use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
const MASK_BUFFER_SIZE: usize = 8192; const MASK_BUFFER_SIZE: usize = 8192;
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4 {
if data.starts_with(b"GET ") || data.starts_with(b"POST") ||
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") {
return "HTTP";
}
}
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
return "TLS-scanner";
}
// Check for SSH
if data.starts_with(b"SSH-") {
return "SSH";
}
// Port scanner (very short data)
if data.len() < 10 {
return "port-scanner";
}
"unknown"
}
/// Handle a bad client by forwarding to mask host /// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client( pub async fn handle_bad_client<R, W>(
client: TcpStream, mut reader: R,
mut writer: W,
initial_data: &[u8], initial_data: &[u8],
config: &ProxyConfig, config: &ProxyConfig,
) { )
if !config.mask { where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
if !config.censorship.mask {
// Masking disabled, just consume data // Masking disabled, just consume data
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
let mask_host = config.mask_host.as_deref() let client_type = detect_client_type(initial_data);
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port; let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!( debug!(
client_type = client_type,
host = %mask_host, host = %mask_host,
port = mask_port, port = mask_port,
data_len = initial_data.len(),
"Forwarding bad client to mask host" "Forwarding bad client to mask host"
); );
@@ -40,33 +78,32 @@ pub async fn handle_bad_client(
TcpStream::connect(&mask_addr) TcpStream::connect(&mask_addr)
).await; ).await;
let mut mask_stream = match connect_result { let mask_stream = match connect_result {
Ok(Ok(s)) => s, Ok(Ok(s)) => s,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
}; };
let (mut mask_read, mut mask_write) = mask_stream.into_split();
// Send initial data to mask host // Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;
} }
// Relay traffic // Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move { let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop { loop {
match client_read.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await; let _ = mask_write.shutdown().await;
break; break;
@@ -85,11 +122,11 @@ pub async fn handle_bad_client(
loop { loop {
match mask_read.read(&mut buf).await { match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = client_write.shutdown().await; let _ = writer.shutdown().await;
break; break;
} }
Ok(n) => { Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() { if writer.write_all(&buf[..n]).await.is_err() {
break; break;
} }
} }
@@ -105,9 +142,9 @@ pub async fn handle_bad_client(
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) { async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await { while let Ok(n) = reader.read(&mut buf).await {
if n == 0 { if n == 0 {
break; break;
} }

View File

@@ -7,14 +7,10 @@ use tokio::time::Instant;
use tracing::{debug, trace, warn, info}; use tracing::{debug, trace, warn, info};
use crate::error::Result; use crate::error::Result;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat.
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
const BUFFER_SIZE: usize = 16384;
// Activity timeout for iOS compatibility (30 minutes) // Activity timeout for iOS compatibility (30 minutes)
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout
const ACTIVITY_TIMEOUT_SECS: u64 = 1800; const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
/// Relay data bidirectionally between client and server /// Relay data bidirectionally between client and server
@@ -25,6 +21,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut server_writer: SW, mut server_writer: SW,
user: &str, user: &str,
stats: Arc<Stats>, stats: Arc<Stats>,
buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
CR: AsyncRead + Unpin + Send + 'static, CR: AsyncRead + Unpin + Send + 'static,
@@ -35,7 +32,6 @@ where
let user_c2s = user.to_string(); let user_c2s = user.to_string();
let user_s2c = user.to_string(); let user_s2c = user.to_string();
// Используем Arc::clone вместо stats.clone()
let stats_c2s = Arc::clone(&stats); let stats_c2s = Arc::clone(&stats);
let stats_s2c = Arc::clone(&stats); let stats_s2c = Arc::clone(&stats);
@@ -44,26 +40,29 @@ where
let c2s_bytes_clone = Arc::clone(&c2s_bytes); let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes); let s2c_bytes_clone = Arc::clone(&s2c_bytes);
// Activity timeout for iOS compatibility
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS); let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
// Client -> Server task with activity timeout let pool_c2s = buffer_pool.clone();
let pool_s2c = buffer_pool.clone();
// Client -> Server task
let c2s = tokio::spawn(async move { let c2s = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; // Get buffer from pool
let mut buf = pool_c2s.get();
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now(); let mut last_activity = Instant::now();
let mut last_log = Instant::now(); let mut last_log = Instant::now();
loop { loop {
// Read with timeout to prevent infinite hang on iOS // Read with timeout
let read_result = tokio::time::timeout( let read_result = tokio::time::timeout(
activity_timeout, activity_timeout,
client_reader.read(&mut buf) client_reader.read(&mut buf)
).await; ).await;
match read_result { match read_result {
// Timeout - no activity for too long
Err(_) => { Err(_) => {
warn!( warn!(
user = %user_c2s, user = %user_c2s,
@@ -76,7 +75,6 @@ where
break; break;
} }
// Read successful
Ok(Ok(0)) => { Ok(Ok(0)) => {
debug!( debug!(
user = %user_c2s, user = %user_c2s,
@@ -101,21 +99,26 @@ where
user = %user_c2s, user = %user_c2s,
bytes = n, bytes = n,
total = total_bytes, total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data" "C->S data"
); );
// Log activity every 10 seconds for large transfers // Log activity every 10 seconds with correct rate
if last_log.elapsed() > Duration::from_secs(10) { let elapsed = last_log.elapsed();
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); if elapsed > Duration::from_secs(10) {
info!( let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
// Changed to DEBUG to reduce log spam
debug!(
user = %user_c2s, user = %user_c2s,
total_bytes = total_bytes, total_bytes = total_bytes,
msgs = msg_count, msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64, rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress" "C->S transfer in progress"
); );
last_log = Instant::now(); last_log = Instant::now();
prev_total_bytes = total_bytes;
} }
if let Err(e) = server_writer.write_all(&buf[..n]).await { if let Err(e) = server_writer.write_all(&buf[..n]).await {
@@ -136,23 +139,23 @@ where
} }
}); });
// Server -> Client task with activity timeout // Server -> Client task
let s2c = tokio::spawn(async move { let s2c = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; // Get buffer from pool
let mut buf = pool_s2c.get();
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now(); let mut last_activity = Instant::now();
let mut last_log = Instant::now(); let mut last_log = Instant::now();
loop { loop {
// Read with timeout to prevent infinite hang on iOS
let read_result = tokio::time::timeout( let read_result = tokio::time::timeout(
activity_timeout, activity_timeout,
server_reader.read(&mut buf) server_reader.read(&mut buf)
).await; ).await;
match read_result { match read_result {
// Timeout - no activity for too long
Err(_) => { Err(_) => {
warn!( warn!(
user = %user_s2c, user = %user_s2c,
@@ -165,7 +168,6 @@ where
break; break;
} }
// Read successful
Ok(Ok(0)) => { Ok(Ok(0)) => {
debug!( debug!(
user = %user_s2c, user = %user_s2c,
@@ -190,21 +192,25 @@ where
user = %user_s2c, user = %user_s2c,
bytes = n, bytes = n,
total = total_bytes, total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"S->C data" "S->C data"
); );
// Log activity every 10 seconds for large transfers let elapsed = last_log.elapsed();
if last_log.elapsed() > Duration::from_secs(10) { if elapsed > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64(); let delta = total_bytes - prev_total_bytes;
info!( let rate = delta as f64 / elapsed.as_secs_f64();
// Changed to DEBUG to reduce log spam
debug!(
user = %user_s2c, user = %user_s2c,
total_bytes = total_bytes, total_bytes = total_bytes,
msgs = msg_count, msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64, rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress" "S->C transfer in progress"
); );
last_log = Instant::now(); last_log = Instant::now();
prev_total_bytes = total_bytes;
} }
if let Err(e) = client_writer.write_all(&buf[..n]).await { if let Err(e) = client_writer.write_all(&buf[..n]).await {

View File

@@ -4,9 +4,11 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::{RwLock, Mutex};
use lru::LruCache; use lru::LruCache;
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
/// Thread-safe statistics /// Thread-safe statistics
#[derive(Default)] #[derive(Default)]
@@ -141,37 +143,57 @@ impl Stats {
} }
} }
// Arc<Stats> Hightech Stats :D /// Sharded Replay attack checker using LRU cache
/// Uses multiple independent LRU caches to reduce lock contention
/// Replay attack checker using LRU cache
pub struct ReplayChecker { pub struct ReplayChecker {
handshakes: RwLock<LruCache<Vec<u8>, ()>>, shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
tls_digests: RwLock<LruCache<Vec<u8>, ()>>, shard_mask: usize,
} }
impl ReplayChecker { impl ReplayChecker {
pub fn new(capacity: usize) -> Self { /// Create new replay checker with specified capacity per shard
let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); /// Total capacity = capacity * num_shards
Self { pub fn new(total_capacity: usize) -> Self {
handshakes: RwLock::new(LruCache::new(cap)), // Use 64 shards for good concurrency
tls_digests: RwLock::new(LruCache::new(cap)), let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap)));
} }
Self {
shards,
shard_mask: num_shards - 1,
}
}
fn get_shard(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
} }
pub fn check_handshake(&self, data: &[u8]) -> bool { pub fn check_handshake(&self, data: &[u8]) -> bool {
self.handshakes.read().contains(&data.to_vec()) let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
} }
pub fn add_handshake(&self, data: &[u8]) { pub fn add_handshake(&self, data: &[u8]) {
self.handshakes.write().put(data.to_vec(), ()); let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
} }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.tls_digests.read().contains(&data.to_vec()) let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
} }
pub fn add_tls_digest(&self, data: &[u8]) { pub fn add_tls_digest(&self, data: &[u8]) {
self.tls_digests.write().put(data.to_vec(), ()); let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
} }
} }
@@ -183,7 +205,6 @@ mod tests {
fn test_stats_shared_counters() { fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
// Симулируем использование из разных "задач"
let stats1 = Arc::clone(&stats); let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats); let stats2 = Arc::clone(&stats);
@@ -191,33 +212,20 @@ mod tests {
stats2.increment_connects_all(); stats2.increment_connects_all();
stats1.increment_connects_all(); stats1.increment_connects_all();
// Все инкременты должны быть видны
assert_eq!(stats.get_connects_all(), 3); assert_eq!(stats.get_connects_all(), 3);
} }
#[test] #[test]
fn test_user_stats_shared() { fn test_replay_checker_sharding() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100);
let data1 = b"test1";
let data2 = b"test2";
let stats1 = Arc::clone(&stats); checker.add_handshake(data1);
let stats2 = Arc::clone(&stats); assert!(checker.check_handshake(data1));
assert!(!checker.check_handshake(data2));
stats1.add_user_octets_from("user1", 100); checker.add_handshake(data2);
stats2.add_user_octets_from("user1", 200); assert!(checker.check_handshake(data2));
stats1.add_user_octets_to("user1", 50);
assert_eq!(stats.get_user_total_octets("user1"), 350);
}
#[test]
fn test_concurrent_user_connects() {
let stats = Arc::new(Stats::new());
stats.increment_user_curr_connects("user1");
stats.increment_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 2);
stats.decrement_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 1);
} }
} }

View File

@@ -45,6 +45,11 @@
//! - when upstream is Pending but pending still has room: accept `to_accept` bytes and //! - when upstream is Pending but pending still has room: accept `to_accept` bytes and
//! encrypt+append ciphertext directly into pending (in-place encryption of appended range) //! encrypt+append ciphertext directly into pending (in-place encryption of appended range)
//! Encrypted stream wrappers using AES-CTR
//!
//! This module provides stateful async stream wrappers that handle
//! encryption/decryption with proper partial read/write handling.
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::io::{self, ErrorKind, Result}; use std::io::{self, ErrorKind, Result};
use std::pin::Pin; use std::pin::Pin;
@@ -58,8 +63,9 @@ use super::state::{StreamState, YieldBuffer};
// ============= Constants ============= // ============= Constants =============
/// Maximum size for pending ciphertext buffer (bounded backpressure). /// Maximum size for pending ciphertext buffer (bounded backpressure).
/// 512 KiB tends to work well for mobile networks and avoids huge latency spikes. /// Reduced to 64KB to prevent bufferbloat on mobile networks.
const MAX_PENDING_WRITE: usize = 524_288; /// 512KB was causing high latency on 3G/LTE connections.
const MAX_PENDING_WRITE: usize = 64 * 1024;
/// Default read buffer capacity (reader mostly decrypts in-place into caller buffer). /// Default read buffer capacity (reader mostly decrypts in-place into caller buffer).
const DEFAULT_READ_CAPACITY: usize = 16 * 1024; const DEFAULT_READ_CAPACITY: usize = 16 * 1024;
@@ -99,22 +105,6 @@ impl StreamState for CryptoReaderState {
// ============= CryptoReader ============= // ============= CryptoReader =============
/// Reader that decrypts data using AES-CTR with proper state machine. /// Reader that decrypts data using AES-CTR with proper state machine.
///
/// This reader handles partial reads correctly by maintaining internal state
/// and never losing any data that has been read from upstream.
///
/// # State Machine
///
/// ┌──────────┐ read ┌──────────┐
/// │ Idle │ ------------> │ Yielding │
/// │ │ <------------ │ │
/// └──────────┘ drained └──────────┘
/// │ │
/// │ errors │
/// ▼ ▼
/// ┌──────────────────────────────────────┐
/// │ Poisoned │
/// └──────────────────────────────────────┘
pub struct CryptoReader<R> { pub struct CryptoReader<R> {
upstream: R, upstream: R,
decryptor: AesCtr, decryptor: AesCtr,
@@ -315,10 +305,6 @@ impl<R: AsyncRead + Unpin> CryptoReader<R> {
// ============= Pending Ciphertext ============= // ============= Pending Ciphertext =============
/// Pending ciphertext buffer with explicit position and strict max size. /// Pending ciphertext buffer with explicit position and strict max size.
///
/// - append plaintext then encrypt appended range in-place - one-touch copy, no extra Vec
/// - move ciphertext from scratch into pending without copying
/// - explicit compaction behavior for long-lived connections
#[derive(Debug)] #[derive(Debug)]
struct PendingCiphertext { struct PendingCiphertext {
buf: BytesMut, buf: BytesMut,
@@ -361,15 +347,13 @@ impl PendingCiphertext {
} }
// Compact when a large prefix was consumed. // Compact when a large prefix was consumed.
if self.pos >= 32 * 1024 { if self.pos >= 16 * 1024 {
let _ = self.buf.split_to(self.pos); let _ = self.buf.split_to(self.pos);
self.pos = 0; self.pos = 0;
} }
} }
/// Replace the entire pending ciphertext by moving `src` in (swap, no copy). /// Replace the entire pending ciphertext by moving `src` in (swap, no copy).
///
/// Precondition: src.len() <= max_len.
fn replace_with(&mut self, mut src: BytesMut) { fn replace_with(&mut self, mut src: BytesMut) {
debug_assert!(src.len() <= self.max_len); debug_assert!(src.len() <= self.max_len);
@@ -381,12 +365,6 @@ impl PendingCiphertext {
} }
/// Append plaintext and encrypt appended range in-place. /// Append plaintext and encrypt appended range in-place.
///
/// This is the high-throughput buffering path:
/// - copy plaintext into pending buffer
/// - encrypt only the newly appended bytes
///
/// CTR state advances by exactly plaintext.len().
fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> { fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> {
if plaintext.is_empty() { if plaintext.is_empty() {
return Ok(()); return Ok(());
@@ -444,21 +422,10 @@ impl StreamState for CryptoWriterState {
// ============= CryptoWriter ============= // ============= CryptoWriter =============
/// Writer that encrypts data using AES-CTR with correct async semantics. /// Writer that encrypts data using AES-CTR with correct async semantics.
///
/// - CTR state advances exactly by the number of bytes we report as written
/// - If upstream blocks, ciphertext is buffered/bounded
/// - Backpressure is applied when buffer is full
pub struct CryptoWriter<W> { pub struct CryptoWriter<W> {
upstream: W, upstream: W,
encryptor: AesCtr, encryptor: AesCtr,
state: CryptoWriterState, state: CryptoWriterState,
/// Scratch ciphertext for fast "write-through" path.
///
/// Flow:
/// - encrypt plaintext into scratch
/// - try upstream write
/// - if Pending/partial: move remainder into pending without re-encrypting
scratch: BytesMut, scratch: BytesMut,
} }
@@ -531,9 +498,6 @@ impl<W> CryptoWriter<W> {
} }
/// Select how many plaintext bytes can be accepted in buffering path /// Select how many plaintext bytes can be accepted in buffering path
///
/// Requirement: worst case - upstream pending, must buffer all ciphertext
/// for the accepted bytes
fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize { fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize {
if buf_len == 0 { if buf_len == 0 {
return 0; return 0;
@@ -557,11 +521,6 @@ impl<W> CryptoWriter<W> {
impl<W: AsyncWrite + Unpin> CryptoWriter<W> { impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
/// Flush as much pending ciphertext as possible /// Flush as much pending ciphertext as possible
///
/// Returns
/// - Ready(Ok(())) if all pending is flushed or was none
/// - Pending if upstream would block
/// - Ready(Err(_)) on error
fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop { loop {
match &mut self.state { match &mut self.state {
@@ -606,14 +565,6 @@ impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
pending.advance(n); pending.advance(n);
trace!(
flushed = n,
pending_left = pending.pending_len(),
"CryptoWriter: flushed pending ciphertext"
);
// continue loop to flush more
continue; continue;
} }
} }
@@ -643,9 +594,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
} }
// 1) If we have pending ciphertext, prioritize flushing it // 1) If we have pending ciphertext, prioritize flushing it
// If upstream pending
// -> still accept some plaintext ONLY if we can buffer
// all ciphertext for the accepted portion - bounded
if matches!(this.state, CryptoWriterState::Flushing { .. }) { if matches!(this.state, CryptoWriterState::Flushing { .. }) {
match this.poll_flush_pending(cx) { match this.poll_flush_pending(cx) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
@@ -654,8 +602,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => { Poll::Pending => {
// Upstream blocked. Apply ideal backpressure // Upstream blocked. Apply ideal backpressure
// - accept up to remaining pending capacity
// - if no capacity -> pending
let to_accept = let to_accept =
Self::select_to_accept_for_buffering(&this.state, buf.len()); Self::select_to_accept_for_buffering(&this.state, buf.len());
@@ -670,11 +616,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
let plaintext = &buf[..to_accept]; let plaintext = &buf[..to_accept];
// Disjoint borrows: borrow encryptor and state separately via a match // Disjoint borrows
let encryptor = &mut this.encryptor; let encryptor = &mut this.encryptor;
let pending = Self::ensure_pending(&mut this.state); let pending = Self::ensure_pending(&mut this.state);
// Should not WouldBlock because to_accept <= remaining_capacity
if let Err(e) = pending.push_encrypted(encryptor, plaintext) { if let Err(e) = pending.push_encrypted(encryptor, plaintext) {
if e.kind() == ErrorKind::WouldBlock { if e.kind() == ErrorKind::WouldBlock {
return Poll::Pending; return Poll::Pending;
@@ -682,13 +627,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
trace!(
accepted = to_accept,
pending_len = pending.pending_len(),
pending_cap = pending.remaining_capacity(),
"CryptoWriter: upstream Pending, buffered ciphertext (accepted plaintext)"
);
return Poll::Ready(Ok(to_accept)); return Poll::Ready(Ok(to_accept));
} }
} }
@@ -697,9 +635,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
// 2) Fast path: pending empty -> write-through // 2) Fast path: pending empty -> write-through
debug_assert!(matches!(this.state, CryptoWriterState::Idle)); debug_assert!(matches!(this.state, CryptoWriterState::Idle));
// Worst-case buffering requirement
// - If upstream becomes pending -> buffer full ciphertext for accepted bytes
// -> accept at most MAX_PENDING_WRITE per poll_write call
let to_accept = buf.len().min(MAX_PENDING_WRITE); let to_accept = buf.len().min(MAX_PENDING_WRITE);
let plaintext = &buf[..to_accept]; let plaintext = &buf[..to_accept];
@@ -708,18 +643,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) { match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) {
Poll::Pending => { Poll::Pending => {
// Upstream blocked: buffer FULL ciphertext for accepted bytes. // Upstream blocked: buffer FULL ciphertext for accepted bytes.
// Move scratch into pending without copying.
let ciphertext = std::mem::take(&mut this.scratch); let ciphertext = std::mem::take(&mut this.scratch);
let pending = Self::ensure_pending(&mut this.state); let pending = Self::ensure_pending(&mut this.state);
pending.replace_with(ciphertext); pending.replace_with(ciphertext);
trace!(
accepted = to_accept,
pending_len = pending.pending_len(),
"CryptoWriter: write-through got Pending, buffered full ciphertext"
);
Poll::Ready(Ok(to_accept)) Poll::Ready(Ok(to_accept))
} }
@@ -736,26 +664,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == this.scratch.len() { if n == this.scratch.len() {
trace!(
accepted = to_accept,
ciphertext_len = this.scratch.len(),
"CryptoWriter: write-through wrote full ciphertext directly"
);
this.scratch.clear(); this.scratch.clear();
return Poll::Ready(Ok(to_accept)); return Poll::Ready(Ok(to_accept));
} }
// Partial upstream write of ciphertext: // Partial upstream write of ciphertext
// We accepted `to_accept` plaintext bytes, CTR already advanced for to_accept
// Must buffer the remainder ciphertext
warn!(
accepted = to_accept,
ciphertext_len = this.scratch.len(),
written_ciphertext = n,
"CryptoWriter: partial upstream write, buffering remainder"
);
// Split off remainder without copying
let remainder = this.scratch.split_off(n); let remainder = this.scratch.split_off(n);
this.scratch.clear(); this.scratch.clear();
@@ -788,7 +701,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
let this = self.get_mut(); let this = self.get_mut();
// Best-effort flush pending ciphertext before shutdown // Best-effort flush pending ciphertext before shutdown
// If upstream blocks, proceed to shutdown anyway
match this.poll_flush_pending(cx) { match this.poll_flush_pending(cx) {
Poll::Pending => { Poll::Pending => {
debug!( debug!(
@@ -807,9 +719,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
// ============= PassthroughStream ============= // ============= PassthroughStream =============
/// Passthrough stream for fast mode - no encryption/decryption /// Passthrough stream for fast mode - no encryption/decryption
///
/// Used when keys are set up so that client and Telegram use the same
/// encryption, allowing data to pass through without re-encryption
pub struct PassthroughStream<S> { pub struct PassthroughStream<S> {
inner: S, inner: S,
} }