14 Commits

Author SHA1 Message Date
Alexey
3881ba9bed 1.1.1.0 2026-01-20 02:09:56 +03:00
Alexey
5ac9089ccb Update README.md 2026-01-20 01:39:59 +03:00
Alexey
eb8b991818 Update README.md 2026-01-20 01:32:39 +03:00
Alexey
2ce8fbb2cc 1.1.0.0 2026-01-20 01:20:02 +03:00
Alexey
038f0cd5d1 Update README.md 2026-01-19 23:52:31 +03:00
Alexey
efea3f981d Update README.md 2026-01-19 23:51:43 +03:00
Alexey
42ce9dd671 Update README.md 2026-01-12 22:11:21 +03:00
Alexey
4fa6867056 Merge pull request #7 from telemt/1.0.3.0
1.0.3.0
2026-01-12 00:49:31 +03:00
Alexey
54ea6efdd0 Global rewrite of AES-CTR + Upstream Pending + to_accept selection 2026-01-12 00:46:51 +03:00
brekotis
27ac32a901 Fixes in TLS for iOS 2026-01-12 00:32:42 +03:00
Alexey
829f53c123 Fixes for iOS 2026-01-11 22:59:51 +03:00
Alexey
43eae6127d Update README.md 2026-01-10 22:17:03 +03:00
Alexey
a03212c8cc Update README.md 2026-01-10 22:15:02 +03:00
Alexey
2613969a7c Update rust.yml 2026-01-09 23:15:52 +03:00
16 changed files with 1449 additions and 1546 deletions

View File

@@ -10,8 +10,8 @@ env:
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always
jobs: jobs:
build-and-test: build:
name: Build & Test name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

174
README.md
View File

@@ -16,6 +16,7 @@
- [IP](#bind-on-ip) - [IP](#bind-on-ip)
- [SOCKS](#socks45-as-upstream) - [SOCKS](#socks45-as-upstream)
- [FAQ](#faq) - [FAQ](#faq)
- [Recognizability for DPI + crawler](#recognizability-for-dpi-and-crawler)
- [Telegram Calls](#telegram-calls-via-mtproxy) - [Telegram Calls](#telegram-calls-via-mtproxy)
- [DPI](#how-does-dpi-see-mtproxy-tls) - [DPI](#how-does-dpi-see-mtproxy-tls)
- [Whitelist on Network Level](#whitelist-on-ip) - [Whitelist on Network Level](#whitelist-on-ip)
@@ -118,44 +119,100 @@ then Ctrl+X -> Y -> Enter to save
## Configuration ## Configuration
### Minimal Configuration for First Start ### Minimal Configuration for First Start
```toml ```toml
port = 443 # Listening port # === General Settings ===
show_links = ["tele", "hello"] # Specify users, for whom will be displayed the links [general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[users] [general.modes]
tele = "00000000000000000000000000000000" # Replace the secret with one generated before classic = false
hello = "00000000000000000000000000000000" # Replace the secret with one generated before secure = false
tls = true
[modes] # === Server Binding ===
classic = false # Plain obfuscated mode [server]
secure = false # dd-prefix mode port = 443
tls = true # Fake TLS - ee-prefix listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
tls_domain = "petrovich.ru" # Domain for ee-secret and masking # Listen on multiple interfaces/IPs (overrides listen_addr_*)
mask = true # Enable masking of bad traffic [[server.listeners]]
mask_host = "petrovich.ru" # Optional override for mask destination ip = "0.0.0.0"
mask_port = 443 # Port for masking # announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
prefer_ipv6 = false # Try IPv6 DCs first if true [[server.listeners]]
fast_mode = true # Use "fast" obfuscation variant ip = "::"
client_keepalive = 600 # Seconds # === Timeouts (in seconds) ===
client_ack_timeout = 300 # Seconds [timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# username "hello" is used for example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
``` ```
### Advanced ### Advanced
#### Adtag #### Adtag
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to the end of config.toml and specify it To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[general]`
```toml ```toml
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
``` ```
#### Listening and Announce IPs #### Listening and Announce IPs
To specify listening address and/or address in links, add to the end of config.toml: To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml:
```toml ```toml
[[listeners]] [[server.listeners]]
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
announce_ip = "1.2.3.4" # IP in links; comment with # if not used announce_ip = "1.2.3.4" # IP in links; comment with # if not used
``` ```
#### Upstream Manager #### Upstream Manager
To specify upstream, add to the end of config.toml: To specify upstream, add to section `[[upstreams]]` of config.toml:
##### Bind on IP ##### Bind on IP
```toml ```toml
[[upstreams]] [[upstreams]]
@@ -186,6 +243,77 @@ enabled = true
``` ```
## FAQ ## FAQ
### Recognizability for DPI and crawler
Since version 1.1, masking has been thoroughly refined: any client that does not present a valid key
is transparently redirected to the target host.
- We consider this a breakthrough: few implementations have achieved it in such a complete form.
- Based on this: if configured correctly, **TLS mode is completely identical to real-life handshake + communication** with a specified host:
```bash
root@debian:~/telemt# curl -v -I --resolve petrovich.ru:443:212.220.88.77 https://petrovich.ru/
* Added petrovich.ru:443:212.220.88.77 to DNS cache
* Hostname petrovich.ru was found in DNS cache
* Trying 212.220.88.77:443...
* Connected to petrovich.ru (212.220.88.77) port 443 (#0)
* ALPN: offers h2,http/1.1
* TLSv1.3 (OUT), TLS handshake, Client hello (1):
* CAfile: /etc/ssl/certs/ca-certificates.crt
* CApath: /etc/ssl/certs
* TLSv1.3 (IN), TLS handshake, Server hello (2):
* TLSv1.3 (IN), TLS handshake, Encrypted Extensions (8):
* TLSv1.3 (IN), TLS handshake, Certificate (11):
* TLSv1.3 (IN), TLS handshake, CERT verify (15):
* TLSv1.3 (IN), TLS handshake, Finished (20):
* TLSv1.3 (OUT), TLS change cipher, Change cipher spec (1):
* TLSv1.3 (OUT), TLS handshake, Finished (20):
* SSL connection using TLSv1.3 / TLS_AES_256_GCM_SHA384
* ALPN: server did not agree on a protocol. Uses default.
* Server certificate:
* subject: C=RU; ST=Saint Petersburg; L=Saint Petersburg; O=STD Petrovich; CN=*.petrovich.ru
* start date: Jan 28 11:21:01 2025 GMT
* expire date: Mar 1 11:21:00 2026 GMT
* subjectAltName: host "petrovich.ru" matched cert's "petrovich.ru"
* issuer: C=BE; O=GlobalSign nv-sa; CN=GlobalSign RSA OV SSL CA 2018
* SSL certificate verify ok.
* using HTTP/1.x
> HEAD / HTTP/1.1
> Host: petrovich.ru
> User-Agent: curl/7.88.1
> Accept: */*
>
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* old SSL session ID is stale, removing
< HTTP/1.1 200 OK
HTTP/1.1 200 OK
< Server: Variti/0.9.3a
Server: Variti/0.9.3a
< Date: Thu, 01 Jan 2026 00:00:00 GMT
Date: Thu, 01 Jan 2026 00:00:00 GMT
< Access-Control-Allow-Origin: *
Access-Control-Allow-Origin: *
< Content-Type: text/html
Content-Type: text/html
< Cache-Control: no-store
Cache-Control: no-store
< Expires: Thu, 01 Jan 2026 00:00:00 GMT
Expires: Thu, 01 Jan 2026 00:00:00 GMT
< Pragma: no-cache
Pragma: no-cache
< Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
< Content-Type: text/html
Content-Type: text/html
< Content-Length: 31253
Content-Length: 31253
< Connection: keep-alive
Connection: keep-alive
< Keep-Alive: timeout=60
Keep-Alive: timeout=60
<
* Connection #0 to host petrovich.ru left intact
```
### Telegram Calls via MTProxy ### Telegram Calls via MTProxy
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated - Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
### How does DPI see MTProxy TLS? ### How does DPI see MTProxy TLS?
@@ -231,6 +359,10 @@ telemt config.toml
- Memory safety and reduced attack surface - Memory safety and reduced attack surface
- Tokio's asynchronous architecture - Tokio's asynchronous architecture
## Issues
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
- ✅ [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
## Roadmap ## Roadmap
- Public IP in links - Public IP in links
- Config Reload-on-fly - Config Reload-on-fly

View File

@@ -1,13 +1,78 @@
port = 443 # === General Settings ===
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[users] [general.modes]
user1 = "00000000000000000000000000000000" classic = false
secure = false
[modes]
classic = true
secure = true
tls = true tls = true
tls_domain = "www.github.com" # === Server Binding ===
fast_mode = true [server]
prefer_ipv6 = false port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# username "hello" is used for example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
# === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]

View File

@@ -7,6 +7,29 @@ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
// ============= Helper Defaults =============
fn default_true() -> bool { true }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_weight() -> u16 { 1 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes { pub struct ProxyModes {
#[serde(default)] #[serde(default)]
@@ -17,26 +40,185 @@ pub struct ProxyModes {
pub tls: bool, pub tls: bool,
} }
fn default_true() -> bool { true }
fn default_weight() -> u16 { 1 }
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { classic: true, secure: true, tls: true } Self { classic: true, secure: true, tls: true }
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
#[serde(default)]
pub modes: ProxyModes,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub ad_tag: Option<String>,
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
ad_tag: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
fake_cert_len: default_fake_cert_len(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")] #[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType { pub enum UpstreamType {
Direct { Direct {
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface interface: Option<String>,
}, },
Socks4 { Socks4 {
address: String, // IP:Port of SOCKS server address: String,
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface for connection to SOCKS interface: Option<String>,
#[serde(default)] #[serde(default)]
user_id: Option<String>, user_id: Option<String>,
}, },
@@ -65,157 +247,35 @@ pub struct UpstreamConfig {
pub struct ListenerConfig { pub struct ListenerConfig {
pub ip: IpAddr, pub ip: IpAddr,
#[serde(default)] #[serde(default)]
pub announce_ip: Option<IpAddr>, // IP to show in tg:// links pub announce_ip: Option<IpAddr>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] // ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig { pub struct ProxyConfig {
#[serde(default = "default_port")] #[serde(default)]
pub port: u16, pub general: GeneralConfig,
#[serde(default)] #[serde(default)]
pub users: HashMap<String, String>, pub server: ServerConfig,
#[serde(default)] #[serde(default)]
pub ad_tag: Option<String>, pub timeouts: TimeoutsConfig,
#[serde(default)] #[serde(default)]
pub modes: ProxyModes, pub censorship: AntiCensorshipConfig,
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)] #[serde(default)]
pub mask_host: Option<String>, pub access: AccessConfig,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
#[serde(default = "default_handshake_timeout")]
pub client_handshake_timeout: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect_timeout: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack_timeout: u64,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
// New fields
#[serde(default)] #[serde(default)]
pub upstreams: Vec<UpstreamConfig>, pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
#[serde(default)] #[serde(default)]
pub show_link: Vec<String>, pub show_link: Vec<String>,
} }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_handshake_timeout() -> u64 { 10 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 600 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
impl Default for ProxyConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
port: default_port(),
users,
ad_tag: None,
modes: ProxyModes::default(),
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
client_handshake_timeout: default_handshake_timeout(),
tg_connect_timeout: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack_timeout: default_ack_timeout(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
fake_cert_len: default_fake_cert_len(),
upstreams: Vec::new(),
listeners: Vec::new(),
show_link: Vec::new(),
}
}
}
impl ProxyConfig { impl ProxyConfig {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let content = std::fs::read_to_string(path) let content = std::fs::read_to_string(path)
@@ -225,7 +285,7 @@ impl ProxyConfig {
.map_err(|e| ProxyError::Config(e.to_string()))?; .map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.users { for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret { return Err(ProxyError::InvalidSecret {
user: user.clone(), user: user.clone(),
@@ -234,26 +294,37 @@ impl ProxyConfig {
} }
} }
// Default mask_host // Validate tls_domain
if config.mask_host.is_none() { if config.censorship.tls_domain.is_empty() {
config.mask_host = Some(config.tls_domain.clone()); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Warn if using default tls_domain
if config.censorship.tls_domain == "www.google.com" {
tracing::warn!("Using default tls_domain (www.google.com). Consider setting a custom domain in config.toml");
}
// Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() {
tracing::info!("mask_host not set, using tls_domain ({}) for masking", config.censorship.tls_domain);
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
// Random fake_cert_len // Random fake_cert_len
use rand::Rng; use rand::Rng;
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::thread_rng().gen_range(1024..4096);
// Migration: Populate listeners if empty // Migration: Populate listeners if empty
if config.listeners.is_empty() { if config.server.listeners.is_empty() {
if let Ok(ipv4) = config.listen_addr_ipv4.parse::<IpAddr>() { if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv4, ip: ipv4,
announce_ip: None, announce_ip: None,
}); });
} }
if let Some(ipv6_str) = &config.listen_addr_ipv6 { if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() { if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv6, ip: ipv6,
announce_ip: None, announce_ip: None,
}); });
@@ -274,14 +345,21 @@ impl ProxyConfig {
} }
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.users.is_empty() { if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string())); return Err(ProxyError::Config("No users configured".to_string()));
} }
if !self.modes.classic && !self.modes.secure && !self.modes.tls { if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
// Validate tls_domain format (basic check)
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
));
}
Ok(()) Ok(())
} }
} }

View File

@@ -297,16 +297,16 @@ pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Result with optional bad client handling /// Result with optional bad client handling
#[derive(Debug)] #[derive(Debug)]
pub enum HandshakeResult<T> { pub enum HandshakeResult<T, R, W> {
/// Handshake succeeded /// Handshake succeeded
Success(T), Success(T),
/// Client failed validation, needs masking /// Client failed validation, needs masking. Returns ownership of streams.
BadClient, BadClient { reader: R, writer: W },
/// Error occurred /// Error occurred
Error(ProxyError), Error(ProxyError),
} }
impl<T> HandshakeResult<T> { impl<T, R, W> HandshakeResult<T, R, W> {
/// Check if successful /// Check if successful
pub fn is_success(&self) -> bool { pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_)) matches!(self, HandshakeResult::Success(_))
@@ -314,49 +314,32 @@ impl<T> HandshakeResult<T> {
/// Check if bad client /// Check if bad client
pub fn is_bad_client(&self) -> bool { pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient) matches!(self, HandshakeResult::BadClient { .. })
}
/// Convert to Result, treating BadClient as error
pub fn into_result(self) -> Result<T> {
match self {
HandshakeResult::Success(v) => Ok(v),
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
HandshakeResult::Error(e) => Err(e),
}
} }
/// Map the success value /// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient => HandshakeResult::BadClient, HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
/// Convert success to Option
pub fn ok(self) -> Option<T> {
match self {
HandshakeResult::Success(v) => Some(v),
_ => None,
}
}
} }
impl<T> From<ProxyError> for HandshakeResult<T> { impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
fn from(err: ProxyError) -> Self { fn from(err: ProxyError) -> Self {
HandshakeResult::Error(err) HandshakeResult::Error(err)
} }
} }
impl<T> From<std::io::Error> for HandshakeResult<T> { impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
fn from(err: std::io::Error) -> Self { fn from(err: std::io::Error) -> Self {
HandshakeResult::Error(ProxyError::Io(err)) HandshakeResult::Error(ProxyError::Io(err))
} }
} }
impl<T> From<StreamError> for HandshakeResult<T> { impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
fn from(err: StreamError) -> Self { fn from(err: StreamError) -> Self {
HandshakeResult::Error(ProxyError::Stream(err)) HandshakeResult::Error(ProxyError::Stream(err))
} }

View File

@@ -20,9 +20,10 @@ mod util;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::stats::Stats; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip; use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -52,12 +53,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
config.validate()?; config.validate()?;
// Log loaded configuration for debugging
info!("=== Configuration Loaded ===");
info!("TLS Domain: {}", config.censorship.tls_domain);
info!("Mask enabled: {}", config.censorship.mask);
info!("Mask host: {}", config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain));
info!("Mask port: {}", config.censorship.mask_port);
info!("Modes: classic={}, secure={}, tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls
);
info!("============================");
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
// Initialize global ReplayChecker
// Using sharded implementation for better concurrency
let replay_checker = Arc::new(ReplayChecker::new(config.access.replay_check_len));
// Initialize Upstream Manager // Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
// Initialize Buffer Pool
// 16KB buffers, max 4096 buffers (~64MB total cached)
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start Health Checks // Start Health Checks
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -70,8 +92,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Start Listeners // Start Listeners
let mut listeners = Vec::new(); let mut listeners = Vec::new();
for listener_conf in &config.listeners { for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.port); let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions { let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(), ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default() ..Default::default()
@@ -83,13 +105,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
info!("Listening on {}", addr); info!("Listening on {}", addr);
// Determine public IP for tg:// links // Determine public IP for tg:// links
// 1. Use explicit announce_ip if set
// 2. If listening on 0.0.0.0 or ::, use detected public IP
// 3. Otherwise use the bind IP
let public_ip = if let Some(ip) = listener_conf.announce_ip { let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip ip
} else if listener_conf.ip.is_unspecified() { } else if listener_conf.ip.is_unspecified() {
// Try to use detected IP of the same family
if listener_conf.ip.is_ipv4() { if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip) detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else { } else {
@@ -103,26 +121,23 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
if !config.show_link.is_empty() { if !config.show_link.is_empty() {
info!("--- Proxy Links for {} ---", public_ip); info!("--- Proxy Links for {} ---", public_ip);
for user_name in &config.show_link { for user_name in &config.show_link {
if let Some(secret) = config.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name); info!("User: {}", user_name);
// Classic if config.general.modes.classic {
if config.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}", info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
// DD (Secure) if config.general.modes.secure {
if config.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
// EE-TLS (FakeTLS) if config.general.modes.tls {
if config.modes.tls { let domain_hex = hex::encode(&config.censorship.tls_domain);
let domain_hex = hex::encode(&config.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.port, secret, domain_hex); public_ip, config.server.port, secret, domain_hex);
} }
} else { } else {
warn!("User '{}' specified in show_link not found in users list", user_name); warn!("User '{}' specified in show_link not found in users list", user_name);
@@ -145,13 +160,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
// Accept loop // Accept loop
// For simplicity in this slice, we just spawn a task for each listener
// In a real high-perf scenario, we might want a more complex accept loop
for listener in listeners { for listener in listeners {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -160,6 +174,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
@@ -167,10 +183,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
peer_addr, peer_addr,
config, config,
stats, stats,
upstream_manager upstream_manager,
replay_checker,
buffer_pool
).run().await { ).run().await {
// Log only relevant errors // Log only relevant errors
// debug!("Connection error: {}", e);
} }
}); });
} }

View File

@@ -167,7 +167,10 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes =============

/// Default relay buffer size.
///
/// Reduced from 64 KiB to 16 KiB so that a single read matches the
/// maximum TLS record payload size, aligning with the buffering strategy
/// introduced for better iOS upload performance.
pub const DEFAULT_BUFFER_SIZE: usize = 16384;

/// Small buffer size used when draining/masking bad clients.
pub const SMALL_BUFFER_SIZE: usize = 8192;

View File

@@ -14,15 +14,16 @@ use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use super::handshake::{ // Use absolute paths to avoid confusion
use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake, handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
}; };
use super::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use super::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
/// Client connection handler (builder struct) /// Client connection handler (builder struct)
pub struct ClientHandler; pub struct ClientHandler;
@@ -35,6 +36,7 @@ pub struct RunningClientHandler {
stats: Arc<Stats>, stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
} }
impl ClientHandler { impl ClientHandler {
@@ -45,12 +47,9 @@ impl ClientHandler {
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
) -> RunningClientHandler { ) -> RunningClientHandler {
// Note: ReplayChecker should be shared globally for proper replay protection
// Creating it per-connection disables replay protection across connections
// TODO: Pass Arc<ReplayChecker> from main.rs
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len));
RunningClientHandler { RunningClientHandler {
stream, stream,
peer, peer,
@@ -58,6 +57,7 @@ impl ClientHandler {
stats, stats,
replay_checker, replay_checker,
upstream_manager, upstream_manager,
buffer_pool,
} }
} }
} }
@@ -73,14 +73,14 @@ impl RunningClientHandler {
// Configure socket // Configure socket
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&self.stream, &self.stream,
self.config.client_keepalive, self.config.timeouts.client_keepalive,
self.config.client_ack_timeout, self.config.timeouts.client_ack,
) { ) {
debug!(peer = %peer, error = %e, "Failed to configure client socket"); debug!(peer = %peer, error = %e, "Failed to configure client socket");
} }
// Perform handshake with timeout // Perform handshake with timeout
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout); let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
// Clone stats for error handling block // Clone stats for error handling block
let stats = self.stats.clone(); let stats = self.stats.clone();
@@ -140,7 +140,9 @@ impl RunningClientHandler {
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; // FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
@@ -153,6 +155,7 @@ impl RunningClientHandler {
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing // Split stream for reading/writing
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@@ -167,8 +170,9 @@ impl RunningClientHandler {
&replay_checker, &replay_checker,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -191,27 +195,23 @@ impl RunningClientHandler {
true, true,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
// Valid TLS but invalid MTProto - drop
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake - dropping");
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Handle authenticated client
// We can't use self.handle_authenticated_inner because self is partially moved
// So we call it as an associated function or method on a new struct,
// or just inline the logic / use a static method.
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
// we should pass them explicitly.
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader,
crypto_writer, crypto_writer,
success, success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config self.config,
buffer_pool
).await ).await
} }
@@ -223,10 +223,12 @@ impl RunningClientHandler {
let peer = self.peer; let peer = self.peer;
// Check if non-TLS modes are enabled // Check if non-TLS modes are enabled
if !self.config.modes.classic && !self.config.modes.secure { if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; // FIX: Split stream into reader/writer for handle_bad_client
let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
@@ -239,6 +241,7 @@ impl RunningClientHandler {
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream // Split stream
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
@@ -254,8 +257,9 @@ impl RunningClientHandler {
false, false,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
@@ -267,11 +271,12 @@ impl RunningClientHandler {
success, success,
self.upstream_manager, self.upstream_manager,
self.stats, self.stats,
self.config self.config,
buffer_pool
).await ).await
} }
/// Static version of handle_authenticated_inner to avoid ownership issues /// Static version of handle_authenticated_inner
async fn handle_authenticated_static<R, W>( async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@@ -279,6 +284,7 @@ impl RunningClientHandler {
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>, stats: Arc<Stats>,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -301,7 +307,7 @@ impl RunningClientHandler {
dc = success.dc_idx, dc = success.dc_idx,
dc_addr = %dc_addr, dc_addr = %dc_addr,
proto = ?success.proto_tag, proto = ?success.proto_tag,
fast_mode = config.fast_mode, fast_mode = config.general.fast_mode,
"Connecting to Telegram" "Connecting to Telegram"
); );
@@ -323,7 +329,7 @@ impl RunningClientHandler {
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user); stats.increment_user_curr_connects(user);
// Relay traffic // Relay traffic using buffer pool
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
client_reader, client_reader,
client_writer, client_writer,
@@ -331,6 +337,7 @@ impl RunningClientHandler {
tg_writer, tg_writer,
user, user,
Arc::clone(&stats), Arc::clone(&stats),
buffer_pool,
).await; ).await;
// Update stats // Update stats
@@ -347,14 +354,14 @@ impl RunningClientHandler {
/// Check user limits (static version) /// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration // Check expiration
if let Some(expiration) = config.user_expirations.get(user) { if let Some(expiration) = config.access.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired { user: user.to_string() });
} }
} }
// Check connection limit // Check connection limit
if let Some(limit) = config.user_max_tcp_conns.get(user) { if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
let current = stats.get_user_curr_connects(user); let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 { if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
@@ -362,7 +369,7 @@ impl RunningClientHandler {
} }
// Check data quota // Check data quota
if let Some(quota) = config.user_data_quota.get(user) { if let Some(quota) = config.access.user_data_quota.get(user) {
let used = stats.get_user_total_octets(user); let used = stats.get_user_total_octets(user);
if used >= *quota { if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
@@ -376,7 +383,7 @@ impl RunningClientHandler {
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize; let idx = (dc_idx.abs() - 1) as usize;
let datacenters = if config.prefer_ipv6 { let datacenters = if config.general.prefer_ipv6 {
&*TG_DATACENTERS_V6 &*TG_DATACENTERS_V6
} else { } else {
&*TG_DATACENTERS_V4 &*TG_DATACENTERS_V4
@@ -400,7 +407,7 @@ impl RunningClientHandler {
success.proto_tag, success.proto_tag,
&success.dec_key, // Client's dec key &success.dec_key, // Client's dec key
success.dec_iv, success.dec_iv,
config.fast_mode, config.general.fast_mode,
); );
// Encrypt nonce // Encrypt nonce

View File

@@ -42,7 +42,7 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)> ) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
@@ -52,7 +52,7 @@ where
// Check minimum length // Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Extract digest for replay check // Extract digest for replay check
@@ -61,36 +61,38 @@ where
// Check for replay // Check for replay
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Build secrets list // Build secrets list
let secrets: Vec<(String, Vec<u8>)> = config.users.iter() let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
.filter_map(|(name, hex)| { .filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
}) })
.collect(); .collect();
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
// Validate handshake // Validate handshake
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
&secrets, &secrets,
config.ignore_time_skew, config.access.ignore_time_skew,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user"); debug!(
return HandshakeResult::BadClient; peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
} }
}; };
// Get secret for response // Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s, Some((_, s)) => s,
None => return HandshakeResult::BadClient, None => return HandshakeResult::BadClient { reader, writer },
}; };
// Build and send response // Build and send response
@@ -98,20 +100,22 @@ where
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
config.fake_cert_len, config.censorship.fake_cert_len,
); );
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await { if let Err(e) = writer.write_all(&response).await {
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
// Record for replay protection // Record for replay protection only after successful handshake
replay_checker.add_tls_digest(digest_half); replay_checker.add_tls_digest(digest_half);
info!( info!(
@@ -136,7 +140,7 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
is_tls: bool, is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)> ) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
@@ -146,23 +150,17 @@ where
// Extract prekey and IV // Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
debug!(
peer = %peer,
dec_prekey_iv = %hex::encode(dec_prekey_iv),
"Extracted prekey+IV from handshake"
);
// Check for replay // Check for replay
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Reversed for encryption direction // Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret // Try each user's secret
for (user, secret_hex) in &config.users { for (user, secret_hex) in &config.access.users {
let secret = match hex::decode(secret_hex) { let secret = match hex::decode(secret_hex) {
Ok(s) => s, Ok(s) => s,
Err(_) => continue, Err(_) => continue,
@@ -183,13 +181,6 @@ where
let mut decryptor = AesCtr::new(&dec_key, dec_iv); let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
trace!(
peer = %peer,
user = %user,
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
"Decrypted handshake tail"
);
// Check protocol tag // Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
@@ -197,20 +188,15 @@ where
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => { None => continue,
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
continue;
}
}; };
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
// Check if mode is enabled // Check if mode is enabled
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure } if is_tls { config.general.modes.tls } else { config.general.modes.secure }
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };
if !mode_ok { if !mode_ok {
@@ -270,13 +256,10 @@ where
} }
debug!(peer = %peer, "MTProto handshake: no matching user found"); debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient HandshakeResult::BadClient { reader, writer }
} }
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection
///
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
/// This means: client's enc_key becomes TG's dec_key and vice versa.
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
client_dec_key: &[u8; 32], client_dec_key: &[u8; 32],
@@ -287,39 +270,22 @@ pub fn generate_tg_nonce(
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
// Check reserved patterns if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
continue;
}
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
continue;
}
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
// In fast mode, we make TG use the same keys as client but swapped:
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
// - What we encrypt TO TG = what we decrypt FROM client
if fast_mode { if fast_mode {
// Put client's dec_key + dec_iv into nonce[8:56]
// This will be used by TG for encryption TO us
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes()); .copy_from_slice(&client_dec_iv.to_be_bytes());
} }
// Now compute what keys WE will use for TG connection
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect(); let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
@@ -329,44 +295,22 @@ pub fn generate_tg_nonce(
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
debug!(
fast_mode = fast_mode,
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
"Generated TG nonce"
);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
} }
} }
/// Encrypt nonce for sending to Telegram /// Encrypt nonce for sending to Telegram
///
/// Only the part from PROTO_TAG_POS onwards is encrypted.
/// The encryption key is derived from enc_key_iv in the nonce itself.
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// enc_key_iv is at nonce[8:56]
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
// Key for encrypting is just the first 32 bytes of enc_key_iv
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv); let mut encryptor = AesCtr::new(&key, iv);
// Encrypt the entire nonce first, then take only the encrypted tail
let encrypted_full = encryptor.encrypt(nonce); let encrypted_full = encryptor.encrypt(nonce);
// Result: unencrypted head + encrypted tail
let mut result = nonce[..PROTO_TAG_POS].to_vec(); let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
trace!(
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
"Encrypted nonce tail"
);
result result
} }

View File

@@ -1,35 +1,73 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::time::Duration; use std::time::Duration;
use std::str;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::debug; use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
const MASK_BUFFER_SIZE: usize = 8192; const MASK_BUFFER_SIZE: usize = 8192;
/// Classify a probing/bad client from the first bytes it sent.
///
/// Best-effort classification used only for logging before the
/// connection is handed to the mask host:
/// - `"HTTP"`         — payload starts with a common HTTP method
/// - `"TLS-scanner"`  — looks like a TLS record header (handshake, v3.x)
/// - `"SSH"`          — SSH protocol banner
/// - `"port-scanner"` — too little data to be a real client
/// - `"unknown"`      — anything else
fn detect_client_type(data: &[u8]) -> &'static str {
    // `starts_with` is safe on slices shorter than the prefix, so no
    // length pre-check is needed. (The previous `data.len() > 4` guard
    // wrongly classified an exactly-4-byte "HEAD" probe as a
    // port-scanner.)
    if data.starts_with(b"GET ")
        || data.starts_with(b"POST")
        || data.starts_with(b"HEAD")
        || data.starts_with(b"PUT ")
        || data.starts_with(b"DELETE")
        || data.starts_with(b"OPTIONS")
    {
        return "HTTP";
    }

    // TLS record header: content type 0x16 (handshake), major version 0x03.
    if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
        return "TLS-scanner";
    }

    // SSH identification banner, e.g. "SSH-2.0-OpenSSH_9.6".
    if data.starts_with(b"SSH-") {
        return "SSH";
    }

    // Very short payloads are typical of port scanners.
    if data.len() < 10 {
        return "port-scanner";
    }

    "unknown"
}
/// Handle a bad client by forwarding to mask host /// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client( pub async fn handle_bad_client<R, W>(
client: TcpStream, mut reader: R,
mut writer: W,
initial_data: &[u8], initial_data: &[u8],
config: &ProxyConfig, config: &ProxyConfig,
) { )
if !config.mask { where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
if !config.censorship.mask {
// Masking disabled, just consume data // Masking disabled, just consume data
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
let mask_host = config.mask_host.as_deref() let client_type = detect_client_type(initial_data);
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port; let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!( debug!(
client_type = client_type,
host = %mask_host, host = %mask_host,
port = mask_port, port = mask_port,
data_len = initial_data.len(),
"Forwarding bad client to mask host" "Forwarding bad client to mask host"
); );
@@ -40,33 +78,32 @@ pub async fn handle_bad_client(
TcpStream::connect(&mask_addr) TcpStream::connect(&mask_addr)
).await; ).await;
let mut mask_stream = match connect_result { let mask_stream = match connect_result {
Ok(Ok(s)) => s, Ok(Ok(s)) => s,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
}; };
let (mut mask_read, mut mask_write) = mask_stream.into_split();
// Send initial data to mask host // Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;
} }
// Relay traffic // Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move { let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop { loop {
match client_read.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await; let _ = mask_write.shutdown().await;
break; break;
@@ -85,11 +122,11 @@ pub async fn handle_bad_client(
loop { loop {
match mask_read.read(&mut buf).await { match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = client_write.shutdown().await; let _ = writer.shutdown().await;
break; break;
} }
Ok(n) => { Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() { if writer.write_all(&buf[..n]).await.is_err() {
break; break;
} }
} }
@@ -105,9 +142,9 @@ pub async fn handle_bad_client(
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) { async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await { while let Ok(n) = reader.read(&mut buf).await {
if n == 0 { if n == 0 {
break; break;
} }

View File

@@ -1,13 +1,17 @@
//! Bidirectional Relay //! Bidirectional Relay
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tracing::{debug, trace, warn}; use tokio::time::Instant;
use tracing::{debug, trace, warn, info};
use crate::error::Result; use crate::error::Result;
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::BufferPool;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
const BUFFER_SIZE: usize = 65536; // Activity timeout for iOS compatibility (30 minutes)
const ACTIVITY_TIMEOUT_SECS: u64 = 1800;
/// Relay data bidirectionally between client and server /// Relay data bidirectionally between client and server
pub async fn relay_bidirectional<CR, CW, SR, SW>( pub async fn relay_bidirectional<CR, CW, SR, SW>(
@@ -17,6 +21,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut server_writer: SW, mut server_writer: SW,
user: &str, user: &str,
stats: Arc<Stats>, stats: Arc<Stats>,
buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
CR: AsyncRead + Unpin + Send + 'static, CR: AsyncRead + Unpin + Send + 'static,
@@ -27,7 +32,6 @@ where
let user_c2s = user.to_string(); let user_c2s = user.to_string();
let user_s2c = user.to_string(); let user_s2c = user.to_string();
// Используем Arc::clone вместо stats.clone()
let stats_c2s = Arc::clone(&stats); let stats_c2s = Arc::clone(&stats);
let stats_s2c = Arc::clone(&stats); let stats_s2c = Arc::clone(&stats);
@@ -36,15 +40,47 @@ where
let c2s_bytes_clone = Arc::clone(&c2s_bytes); let c2s_bytes_clone = Arc::clone(&c2s_bytes);
let s2c_bytes_clone = Arc::clone(&s2c_bytes); let s2c_bytes_clone = Arc::clone(&s2c_bytes);
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS);
let pool_c2s = buffer_pool.clone();
let pool_s2c = buffer_pool.clone();
// Client -> Server task // Client -> Server task
let c2s = tokio::spawn(async move { let c2s = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; // Get buffer from pool
let mut pooled_buf = pool_c2s.get();
// CRITICAL FIX: BytesMut from pool has len 0. We must resize it to be usable as &mut [u8].
// We use the full capacity.
let cap = pooled_buf.capacity();
pooled_buf.resize(cap, 0);
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { loop {
match client_reader.read(&mut buf).await { // Read with timeout
Ok(0) => { let read_result = tokio::time::timeout(
activity_timeout,
client_reader.read(&mut pooled_buf)
).await;
match read_result {
Err(_) => {
warn!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (C->S) - no data received"
);
let _ = server_writer.shutdown().await;
break;
}
Ok(Ok(0)) => {
debug!( debug!(
user = %user_c2s, user = %user_c2s,
total_bytes = total_bytes, total_bytes = total_bytes,
@@ -54,9 +90,11 @@ where
let _ = server_writer.shutdown().await; let _ = server_writer.shutdown().await;
break; break;
} }
Ok(n) => {
Ok(Ok(n)) => {
total_bytes += n as u64; total_bytes += n as u64;
msg_count += 1; msg_count += 1;
last_activity = Instant::now();
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed); c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_c2s.add_user_octets_from(&user_c2s, n as u64); stats_c2s.add_user_octets_from(&user_c2s, n as u64);
@@ -66,11 +104,28 @@ where
user = %user_c2s, user = %user_c2s,
bytes = n, bytes = n,
total = total_bytes, total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data" "C->S data"
); );
if let Err(e) = server_writer.write_all(&buf[..n]).await { // Log activity every 10 seconds with correct rate
let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
debug!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = server_writer.write_all(&pooled_buf[..n]).await {
debug!(user = %user_c2s, error = %e, "Failed to write to server"); debug!(user = %user_c2s, error = %e, "Failed to write to server");
break; break;
} }
@@ -79,7 +134,8 @@ where
break; break;
} }
} }
Err(e) => {
Ok(Err(e)) => {
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error"); debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error");
break; break;
} }
@@ -89,13 +145,38 @@ where
// Server -> Client task // Server -> Client task
let s2c = tokio::spawn(async move { let s2c = tokio::spawn(async move {
let mut buf = vec![0u8; BUFFER_SIZE]; // Get buffer from pool
let mut pooled_buf = pool_s2c.get();
// CRITICAL FIX: Resize buffer
let cap = pooled_buf.capacity();
pooled_buf.resize(cap, 0);
let mut total_bytes = 0u64; let mut total_bytes = 0u64;
let mut prev_total_bytes = 0u64;
let mut msg_count = 0u64; let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { loop {
match server_reader.read(&mut buf).await { let read_result = tokio::time::timeout(
Ok(0) => { activity_timeout,
server_reader.read(&mut pooled_buf)
).await;
match read_result {
Err(_) => {
warn!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (S->C) - no data received"
);
let _ = client_writer.shutdown().await;
break;
}
Ok(Ok(0)) => {
debug!( debug!(
user = %user_s2c, user = %user_s2c,
total_bytes = total_bytes, total_bytes = total_bytes,
@@ -105,9 +186,11 @@ where
let _ = client_writer.shutdown().await; let _ = client_writer.shutdown().await;
break; break;
} }
Ok(n) => {
Ok(Ok(n)) => {
total_bytes += n as u64; total_bytes += n as u64;
msg_count += 1; msg_count += 1;
last_activity = Instant::now();
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed); s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_s2c.add_user_octets_to(&user_s2c, n as u64); stats_s2c.add_user_octets_to(&user_s2c, n as u64);
@@ -117,11 +200,27 @@ where
user = %user_s2c, user = %user_s2c,
bytes = n, bytes = n,
total = total_bytes, total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"S->C data" "S->C data"
); );
if let Err(e) = client_writer.write_all(&buf[..n]).await { let elapsed = last_log.elapsed();
if elapsed > Duration::from_secs(10) {
let delta = total_bytes - prev_total_bytes;
let rate = delta as f64 / elapsed.as_secs_f64();
debug!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress"
);
last_log = Instant::now();
prev_total_bytes = total_bytes;
}
if let Err(e) = client_writer.write_all(&pooled_buf[..n]).await {
debug!(user = %user_s2c, error = %e, "Failed to write to client"); debug!(user = %user_s2c, error = %e, "Failed to write to client");
break; break;
} }
@@ -130,7 +229,8 @@ where
break; break;
} }
} }
Err(e) => {
Ok(Err(e)) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error"); debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break; break;
} }

View File

@@ -4,9 +4,11 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::{RwLock, Mutex};
use lru::LruCache; use lru::LruCache;
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
/// Thread-safe statistics /// Thread-safe statistics
#[derive(Default)] #[derive(Default)]
@@ -141,37 +143,57 @@ impl Stats {
} }
} }
// Arc<Stats> Hightech Stats :D /// Sharded Replay attack checker using LRU cache
/// Uses multiple independent LRU caches to reduce lock contention
/// Replay attack checker using LRU cache
pub struct ReplayChecker { pub struct ReplayChecker {
handshakes: RwLock<LruCache<Vec<u8>, ()>>, shards: Vec<Mutex<LruCache<Vec<u8>, ()>>>,
tls_digests: RwLock<LruCache<Vec<u8>, ()>>, shard_mask: usize,
} }
impl ReplayChecker { impl ReplayChecker {
pub fn new(capacity: usize) -> Self { /// Create new replay checker with specified capacity per shard
let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); /// Total capacity = capacity * num_shards
pub fn new(total_capacity: usize) -> Self {
// Use 64 shards for good concurrency
let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(LruCache::new(cap)));
}
Self { Self {
handshakes: RwLock::new(LruCache::new(cap)), shards,
tls_digests: RwLock::new(LruCache::new(cap)), shard_mask: num_shards - 1,
} }
} }
fn get_shard(&self, key: &[u8]) -> usize {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
}
pub fn check_handshake(&self, data: &[u8]) -> bool { pub fn check_handshake(&self, data: &[u8]) -> bool {
self.handshakes.read().contains(&data.to_vec()) let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
} }
pub fn add_handshake(&self, data: &[u8]) { pub fn add_handshake(&self, data: &[u8]) {
self.handshakes.write().put(data.to_vec(), ()); let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
} }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.tls_digests.read().contains(&data.to_vec()) let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().contains(&data.to_vec())
} }
pub fn add_tls_digest(&self, data: &[u8]) { pub fn add_tls_digest(&self, data: &[u8]) {
self.tls_digests.write().put(data.to_vec(), ()); let shard_idx = self.get_shard(data);
self.shards[shard_idx].lock().put(data.to_vec(), ());
} }
} }
@@ -183,7 +205,6 @@ mod tests {
fn test_stats_shared_counters() { fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
// Симулируем использование из разных "задач"
let stats1 = Arc::clone(&stats); let stats1 = Arc::clone(&stats);
let stats2 = Arc::clone(&stats); let stats2 = Arc::clone(&stats);
@@ -191,33 +212,20 @@ mod tests {
stats2.increment_connects_all(); stats2.increment_connects_all();
stats1.increment_connects_all(); stats1.increment_connects_all();
// Все инкременты должны быть видны
assert_eq!(stats.get_connects_all(), 3); assert_eq!(stats.get_connects_all(), 3);
} }
#[test] #[test]
fn test_user_stats_shared() { fn test_replay_checker_sharding() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100);
let data1 = b"test1";
let data2 = b"test2";
let stats1 = Arc::clone(&stats); checker.add_handshake(data1);
let stats2 = Arc::clone(&stats); assert!(checker.check_handshake(data1));
assert!(!checker.check_handshake(data2));
stats1.add_user_octets_from("user1", 100); checker.add_handshake(data2);
stats2.add_user_octets_from("user1", 200); assert!(checker.check_handshake(data2));
stats1.add_user_octets_to("user1", 50);
assert_eq!(stats.get_user_total_octets("user1"), 350);
}
#[test]
fn test_concurrent_user_connects() {
let stats = Arc::new(Stats::new());
stats.increment_user_curr_connects("user1");
stats.increment_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 2);
stats.decrement_user_curr_connects("user1");
assert_eq!(stats.get_user_curr_connects("user1"), 1);
} }
} }

View File

@@ -11,8 +11,9 @@ use std::sync::Arc;
// ============= Configuration ============= // ============= Configuration =============
/// Default buffer size (64KB - good for MTProto) /// Default buffer size
pub const DEFAULT_BUFFER_SIZE: usize = 64 * 1024; /// CHANGED: Reduced from 64KB to 16KB to match TLS record size and prevent bufferbloat.
pub const DEFAULT_BUFFER_SIZE: usize = 16 * 1024;
/// Default maximum number of pooled buffers /// Default maximum number of pooled buffers
pub const DEFAULT_MAX_BUFFERS: usize = 1024; pub const DEFAULT_MAX_BUFFERS: usize = 1024;

File diff suppressed because it is too large Load Diff

View File

@@ -1,17 +1,36 @@
//! Fake TLS 1.3 stream wrappers //! Fake TLS 1.3 stream wrappers
//! //!
//! This module provides stateful async stream wrappers that handle //! This module provides stateful async stream wrappers that handle TLS record
//! TLS record framing with proper partial read/write handling. //! framing with proper partial read/write handling.
//! //!
//! These are "fake" TLS streams - they wrap data in valid TLS 1.3 //! These are "fake" TLS streams:
//! Application Data records but don't perform actual TLS encryption. //! - We wrap raw bytes into syntactically valid TLS 1.3 records (Application Data).
//! The actual encryption is handled by the crypto layer underneath. //! - We DO NOT perform real TLS handshake/encryption.
//! - Real crypto for MTProto is handled by the crypto layer underneath.
//!
//! Why do we need this?
//! Telegram MTProto proxy "FakeTLS" mode uses a TLS-looking outer layer for
//! domain fronting / traffic camouflage. iOS Telegram clients are known to
//! produce slightly different TLS record sizing patterns than Android/Desktop,
//! including records that exceed 16384 payload bytes by a small overhead.
//! //!
//! Key design principles: //! Key design principles:
//! - Explicit state machines for all async operations //! - Explicit state machines for all async operations
//! - Never lose data on partial reads //! - Never lose data on partial reads
//! - Atomic TLS record formation for writes //! - Atomic TLS record formation for writes
//! - Proper handling of all TLS record types //! - Proper handling of all TLS record types
//!
//! Important nuance (Telegram FakeTLS):
//! - The TLS spec limits "plaintext fragments" to 2^14 (16384) bytes.
//! - However, the on-the-wire record length can exceed 16384 because TLS 1.3
//! uses AEAD and can include tag/overhead/padding.
//! - Telegram FakeTLS clients (notably iOS) may send Application Data records
//! with length up to 16384 + 24 bytes. We accept that as MAX_TLS_CHUNK_SIZE.
//!
//! If you reject those (e.g. validate length <= 16384), you will see errors like:
//! "TLS record too large: 16408 bytes"
//! and uploads from iOS will break (media/file sending), while small traffic
//! may still work.
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind, Result}; use std::io::{self, Error, ErrorKind, Result};
@@ -20,25 +39,29 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::protocol::constants::{ use crate::protocol::constants::{
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_VERSION,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT, MAX_TLS_RECORD_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
MAX_TLS_CHUNK_SIZE,
}; };
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer}; use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
// ============= Constants ============= // ============= Constants =============
/// TLS record header size /// TLS record header size (type + version + length)
const TLS_HEADER_SIZE: usize = 5; const TLS_HEADER_SIZE: usize = 5;
/// Maximum TLS record payload size (16KB as per TLS spec) /// Maximum TLS fragment size per spec (plaintext fragment).
/// We use this for *outgoing* chunking, because we build plain ApplicationData records.
const MAX_TLS_PAYLOAD: usize = 16384; const MAX_TLS_PAYLOAD: usize = 16384;
/// Maximum pending write buffer /// Maximum pending write buffer for one record remainder.
/// Note: we never queue unlimited amount of data here; state holds at most one record.
const MAX_PENDING_WRITE: usize = 64 * 1024; const MAX_PENDING_WRITE: usize = 64 * 1024;
// ============= TLS Record Types ============= // ============= TLS Record Types =============
/// Parsed TLS record header /// Parsed TLS record header (5 bytes)
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
struct TlsRecordHeader { struct TlsRecordHeader {
/// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.) /// Record type (0x17 = Application Data, 0x14 = Change Cipher, etc.)
@@ -50,22 +73,27 @@ struct TlsRecordHeader {
} }
impl TlsRecordHeader { impl TlsRecordHeader {
/// Parse header from 5 bytes /// Parse header from exactly 5 bytes.
///
/// This currently never returns None, but is kept as Option to allow future
/// stricter parsing rules without changing callers.
fn parse(header: &[u8; 5]) -> Option<Self> { fn parse(header: &[u8; 5]) -> Option<Self> {
let record_type = header[0]; let record_type = header[0];
let version = [header[1], header[2]]; let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]); let length = u16::from_be_bytes([header[3], header[4]]);
Some(Self { record_type, version, length })
Some(Self {
record_type,
version,
length,
})
} }
/// Validate the header /// Validate the header.
///
/// Nuances:
/// - We accept TLS 1.0 header version for ClientHello-like records (0x03 0x01),
/// and TLS 1.2/1.3 style version bytes for the rest (we use TLS_VERSION = 0x03 0x03).
/// - For Application Data, Telegram FakeTLS may send payload length up to
/// MAX_TLS_CHUNK_SIZE (16384 + 24).
/// - For other record types we keep stricter bounds to avoid memory abuse.
fn validate(&self) -> Result<()> { fn validate(&self) -> Result<()> {
// Check version (accept TLS 1.0 for ClientHello, TLS 1.2/1.3 for others) // Version: accept TLS 1.0 header (ClientHello quirk) and TLS_VERSION (0x0303).
if self.version != [0x03, 0x01] && self.version != TLS_VERSION { if self.version != [0x03, 0x01] && self.version != TLS_VERSION {
return Err(Error::new( return Err(Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
@@ -73,27 +101,36 @@ impl TlsRecordHeader {
)); ));
} }
// Check length let len = self.length as usize;
if self.length as usize > MAX_TLS_RECORD_SIZE {
return Err(Error::new( // Length checks depend on record type.
ErrorKind::InvalidData, // Telegram FakeTLS: ApplicationData length may be 16384 + 24.
format!("TLS record too large: {} bytes", self.length), match self.record_type {
)); TLS_RECORD_APPLICATION => {
if len > MAX_TLS_CHUNK_SIZE {
return Err(Error::new(
ErrorKind::InvalidData,
format!("TLS record too large: {} bytes (max {})", len, MAX_TLS_CHUNK_SIZE),
));
}
}
// ChangeCipherSpec/Alert/Handshake should never be that large for our usage
// (post-handshake we don't expect Handshake at all).
// Keep strict to reduce attack surface.
_ => {
if len > MAX_TLS_PAYLOAD {
return Err(Error::new(
ErrorKind::InvalidData,
format!("TLS control record too large: {} bytes (max {})", len, MAX_TLS_PAYLOAD),
));
}
}
} }
Ok(()) Ok(())
} }
/// Check if this is an application data record
fn is_application_data(&self) -> bool {
self.record_type == TLS_RECORD_APPLICATION
}
/// Check if this is a change cipher spec record (should be skipped)
fn is_change_cipher_spec(&self) -> bool {
self.record_type == TLS_RECORD_CHANGE_CIPHER
}
/// Build header bytes /// Build header bytes
fn to_bytes(&self) -> [u8; 5] { fn to_bytes(&self) -> [u8; 5] {
[ [
@@ -120,25 +157,20 @@ enum TlsReaderState {
header: HeaderBuffer<TLS_HEADER_SIZE>, header: HeaderBuffer<TLS_HEADER_SIZE>,
}, },
/// Reading the TLS record body /// Reading the TLS record body (payload)
ReadingBody { ReadingBody {
/// Parsed record type
record_type: u8, record_type: u8,
/// Total body length
length: usize, length: usize,
/// Buffer for body data
buffer: BytesMut, buffer: BytesMut,
}, },
/// Have decrypted data ready to yield to caller /// Have buffered data ready to yield to caller
Yielding { Yielding {
/// Buffer containing data to yield
buffer: YieldBuffer, buffer: YieldBuffer,
}, },
/// Stream encountered an error and cannot be used /// Stream encountered an error and cannot be used
Poisoned { Poisoned {
/// The error that caused poisoning
error: Option<io::Error>, error: Option<io::Error>,
}, },
} }
@@ -165,12 +197,13 @@ impl StreamState for TlsReaderState {
// ============= FakeTlsReader ============= // ============= FakeTlsReader =============
/// Reader that unwraps TLS 1.3 records with proper state machine /// Reader that unwraps TLS records (FakeTLS).
/// ///
/// This reader handles partial reads correctly by maintaining internal state /// This wrapper is responsible ONLY for TLS record framing and skipping
/// and never losing any data that has been read from upstream. /// non-data records (like CCS). It does not decrypt TLS: payload bytes are passed
/// as-is to upper layers (crypto stream).
/// ///
/// # State Machine /// State machine overview:
/// ///
/// ┌──────────┐ ┌───────────────┐ /// ┌──────────┐ ┌───────────────┐
/// │ Idle │ -----------------> │ ReadingHeader │ /// │ Idle │ -----------------> │ ReadingHeader │
@@ -178,103 +211,69 @@ impl StreamState for TlsReaderState {
/// ▲ │ /// ▲ │
/// │ header complete /// │ header complete
/// │ │ /// │ │
/// │ /// │
/// │ ┌───────────────┐ /// │ ┌───────────────┐
/// │ skip record │ ReadingBody │ /// │ skip record │ ReadingBody │
/// │ <-------- (CCS) -------- │ │ /// │ <-------- (CCS) -------- │ │
/// │ └───────┬───────┘ /// │ └───────┬───────┘
/// │ │ /// │ │
/// │ body complete /// │ body complete
/// │ drained /// │
/// │ <-----------------┐ │ /// │ ┌───────────────┐
/// │ │ ┌───────────────┐ /// │ │ Yielding │
/// │ └----- │ Yielding │
/// │ └───────────────┘ /// │ └───────────────┘
/// │ /// │
/// │ errors /w any state /// │ errors / w any state
/// ///
/// ┌───────────────────────────────────────────────┐ /// ┌───────────────────────────────────────────────┐
/// │ Poisoned │ /// │ Poisoned │
/// └───────────────────────────────────────────────┘ /// └───────────────────────────────────────────────┘
/// ///
/// NOTE: We must correctly handle partial reads from upstream:
/// - do not assume header arrives in one poll
/// - do not assume body arrives in one poll
/// - never lose already-read bytes
pub struct FakeTlsReader<R> { pub struct FakeTlsReader<R> {
/// Upstream reader
upstream: R, upstream: R,
/// Current state
state: TlsReaderState, state: TlsReaderState,
} }
impl<R> FakeTlsReader<R> { impl<R> FakeTlsReader<R> {
/// Create new fake TLS reader
pub fn new(upstream: R) -> Self { pub fn new(upstream: R) -> Self {
Self { Self { upstream, state: TlsReaderState::Idle }
upstream,
state: TlsReaderState::Idle,
}
} }
/// Get reference to upstream pub fn get_ref(&self) -> &R { &self.upstream }
pub fn get_ref(&self) -> &R { pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
&self.upstream pub fn into_inner(self) -> R { self.upstream }
}
/// Get mutable reference to upstream pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn get_mut(&mut self) -> &mut R { pub fn state_name(&self) -> &'static str { self.state.state_name() }
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
/// Check if stream is in poisoned state
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
/// Get current state name (for debugging)
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
/// Transition to poisoned state
fn poison(&mut self, error: io::Error) { fn poison(&mut self, error: io::Error) {
self.state = TlsReaderState::Poisoned { error: Some(error) }; self.state = TlsReaderState::Poisoned { error: Some(error) };
} }
/// Take error from poisoned state
fn take_poison_error(&mut self) -> io::Error { fn take_poison_error(&mut self) -> io::Error {
match &mut self.state { match &mut self.state {
TlsReaderState::Poisoned { error } => { TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
error.take().unwrap_or_else(|| { io::Error::new(ErrorKind::Other, "stream previously poisoned")
io::Error::new(ErrorKind::Other, "stream previously poisoned") }),
})
}
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
} }
} }
} }
/// Result of polling for header completion
enum HeaderPollResult { enum HeaderPollResult {
/// Need more data
Pending, Pending,
/// EOF at record boundary (clean close)
Eof, Eof,
/// Header complete, parsed successfully
Complete(TlsRecordHeader), Complete(TlsRecordHeader),
/// Error occurred
Error(io::Error), Error(io::Error),
} }
/// Result of polling for body completion
enum BodyPollResult { enum BodyPollResult {
/// Need more data
Pending, Pending,
/// Body complete
Complete(Bytes), Complete(Bytes),
/// Error occurred
Error(io::Error), Error(io::Error),
} }
@@ -291,7 +290,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
let state = std::mem::replace(&mut this.state, TlsReaderState::Idle); let state = std::mem::replace(&mut this.state, TlsReaderState::Idle);
match state { match state {
// Poisoned state - return error // Poisoned state: always return the stored error
TlsReaderState::Poisoned { error } => { TlsReaderState::Poisoned { error } => {
this.state = TlsReaderState::Poisoned { error: None }; this.state = TlsReaderState::Poisoned { error: None };
let err = error.unwrap_or_else(|| { let err = error.unwrap_or_else(|| {
@@ -300,20 +299,18 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
} }
// Have buffered data to yield // Yield buffered plaintext to caller
TlsReaderState::Yielding { mut buffer } => { TlsReaderState::Yielding { mut buffer } => {
if buf.remaining() == 0 { if buf.remaining() == 0 {
this.state = TlsReaderState::Yielding { buffer }; this.state = TlsReaderState::Yielding { buffer };
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Copy as much as possible to output
let to_copy = buffer.remaining().min(buf.remaining()); let to_copy = buffer.remaining().min(buf.remaining());
let dst = buf.initialize_unfilled_to(to_copy); let dst = buf.initialize_unfilled_to(to_copy);
let copied = buffer.copy_to(dst); let copied = buffer.copy_to(dst);
buf.advance(copied); buf.advance(copied);
// If buffer is drained, transition to Idle
if buffer.is_empty() { if buffer.is_empty() {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
} else { } else {
@@ -323,23 +320,21 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Ready to read a new TLS record // Start reading new record
TlsReaderState::Idle => { TlsReaderState::Idle => {
if buf.remaining() == 0 { if buf.remaining() == 0 {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
// Start reading header
this.state = TlsReaderState::ReadingHeader { this.state = TlsReaderState::ReadingHeader {
header: HeaderBuffer::new(), header: HeaderBuffer::new(),
}; };
// Continue to ReadingHeader // loop continues and will handle ReadingHeader
} }
// Reading TLS record header // Read TLS header (5 bytes)
TlsReaderState::ReadingHeader { mut header } => { TlsReaderState::ReadingHeader { mut header } => {
// Poll to fill header
let result = poll_read_header(&mut this.upstream, cx, &mut header); let result = poll_read_header(&mut this.upstream, cx, &mut header);
match result { match result {
@@ -348,6 +343,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Pending; return Poll::Pending;
} }
HeaderPollResult::Eof => { HeaderPollResult::Eof => {
// Clean EOF at record boundary
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
@@ -356,15 +352,12 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
HeaderPollResult::Complete(parsed) => { HeaderPollResult::Complete(parsed) => {
// Validate header
if let Err(e) = parsed.validate() { if let Err(e) = parsed.validate() {
this.poison(Error::new(e.kind(), e.to_string())); this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
let length = parsed.length as usize; let length = parsed.length as usize;
// Transition to reading body
this.state = TlsReaderState::ReadingBody { this.state = TlsReaderState::ReadingBody {
record_type: parsed.record_type, record_type: parsed.record_type,
length, length,
@@ -374,7 +367,7 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
} }
} }
// Reading TLS record body // Read TLS payload
TlsReaderState::ReadingBody { record_type, length, mut buffer } => { TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length); let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
@@ -388,15 +381,15 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
BodyPollResult::Complete(data) => { BodyPollResult::Complete(data) => {
// Handle different record types
match record_type { match record_type {
TLS_RECORD_CHANGE_CIPHER => { TLS_RECORD_CHANGE_CIPHER => {
// Skip Change Cipher Spec, read next record // CCS is expected in some clients, ignore it.
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
continue; continue;
} }
TLS_RECORD_APPLICATION => { TLS_RECORD_APPLICATION => {
// Application data - yield to caller // This is what we actually want.
if data.is_empty() { if data.is_empty() {
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
continue; continue;
@@ -405,25 +398,26 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
this.state = TlsReaderState::Yielding { this.state = TlsReaderState::Yielding {
buffer: YieldBuffer::new(data), buffer: YieldBuffer::new(data),
}; };
// Continue to yield // loop continues and will yield immediately
} }
TLS_RECORD_ALERT => { TLS_RECORD_ALERT => {
// TLS Alert - treat as EOF // Treat TLS alert as EOF-like termination.
this.state = TlsReaderState::Idle; this.state = TlsReaderState::Idle;
return Poll::Ready(Ok(())); return Poll::Ready(Ok(()));
} }
TLS_RECORD_HANDSHAKE => { TLS_RECORD_HANDSHAKE => {
let err = Error::new( // After FakeTLS handshake is done, we do not expect any Handshake records.
ErrorKind::InvalidData, let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
"unexpected TLS handshake record"
);
this.poison(Error::new(err.kind(), err.to_string())); this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
} }
_ => { _ => {
let err = Error::new( let err = Error::new(
ErrorKind::InvalidData, ErrorKind::InvalidData,
format!("unknown TLS record type: 0x{:02x}", record_type) format!("unknown TLS record type: 0x{:02x}", record_type),
); );
this.poison(Error::new(err.kind(), err.to_string())); this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err)); return Poll::Ready(Err(err));
@@ -459,8 +453,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
} else { } else {
return HeaderPollResult::Error(Error::new( return HeaderPollResult::Error(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("unexpected EOF in TLS header (got {} of 5 bytes)", format!(
header.as_slice().len()) "unexpected EOF in TLS header (got {} of 5 bytes)",
header.as_slice().len()
),
)); ));
} }
} }
@@ -469,14 +465,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
} }
} }
// Parse header
let header_bytes = *header.as_array(); let header_bytes = *header.as_array();
match TlsRecordHeader::parse(&header_bytes) { match TlsRecordHeader::parse(&header_bytes) {
Some(h) => HeaderPollResult::Complete(h), Some(h) => HeaderPollResult::Complete(h),
None => HeaderPollResult::Error(Error::new( None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
ErrorKind::InvalidData,
"failed to parse TLS header"
)),
} }
} }
@@ -487,10 +479,12 @@ fn poll_read_body<R: AsyncRead + Unpin>(
buffer: &mut BytesMut, buffer: &mut BytesMut,
target_len: usize, target_len: usize,
) -> BodyPollResult { ) -> BodyPollResult {
// NOTE: This implementation uses a temporary Vec to avoid tricky borrow/lifetime
// issues with BytesMut spare capacity and ReadBuf across polls.
// It's safe and correct; optimization is possible if needed.
while buffer.len() < target_len { while buffer.len() < target_len {
let remaining = target_len - buffer.len(); let remaining = target_len - buffer.len();
// Read into a temporary buffer
let mut temp = vec![0u8; remaining.min(8192)]; let mut temp = vec![0u8; remaining.min(8192)];
let mut read_buf = ReadBuf::new(&mut temp); let mut read_buf = ReadBuf::new(&mut temp);
@@ -502,8 +496,11 @@ fn poll_read_body<R: AsyncRead + Unpin>(
if n == 0 { if n == 0 {
return BodyPollResult::Error(Error::new( return BodyPollResult::Error(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("unexpected EOF in TLS body (got {} of {} bytes)", format!(
buffer.len(), target_len) "unexpected EOF in TLS body (got {} of {} bytes)",
buffer.len(),
target_len
),
)); ));
} }
buffer.extend_from_slice(&temp[..n]); buffer.extend_from_slice(&temp[..n]);
@@ -515,10 +512,9 @@ fn poll_read_body<R: AsyncRead + Unpin>(
} }
impl<R: AsyncRead + Unpin> FakeTlsReader<R> { impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
/// Read exactly n bytes through TLS layer /// Read exactly n bytes through TLS layer.
/// ///
/// This is a convenience method that accumulates data across /// This accumulates data across multiple TLS ApplicationData records.
/// multiple TLS records until exactly n bytes are available.
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> { pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
if self.is_poisoned() { if self.is_poisoned() {
return Err(self.take_poison_error()); return Err(self.take_poison_error());
@@ -533,7 +529,7 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
if read == 0 { if read == 0 {
return Err(Error::new( return Err(Error::new(
ErrorKind::UnexpectedEof, ErrorKind::UnexpectedEof,
format!("expected {} bytes, got {}", n, result.len()) format!("expected {} bytes, got {}", n, result.len()),
)); ));
} }
@@ -546,23 +542,19 @@ impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
// ============= FakeTlsWriter State ============= // ============= FakeTlsWriter State =============
/// State machine states for FakeTlsWriter
#[derive(Debug)] #[derive(Debug)]
enum TlsWriterState { enum TlsWriterState {
/// Ready to accept new data /// Ready to accept new data
Idle, Idle,
/// Writing a complete TLS record /// Writing a complete TLS record (header + body), possibly partially
WritingRecord { WritingRecord {
/// Complete record (header + body) to write
record: WriteBuffer, record: WriteBuffer,
/// Original payload size (for return value calculation)
payload_size: usize, payload_size: usize,
}, },
/// Stream encountered an error and cannot be used /// Stream encountered an error and cannot be used
Poisoned { Poisoned {
/// The error that caused poisoning
error: Option<io::Error>, error: Option<io::Error>,
}, },
} }
@@ -587,94 +579,46 @@ impl StreamState for TlsWriterState {
// ============= FakeTlsWriter ============= // ============= FakeTlsWriter =============
/// Writer that wraps data in TLS 1.3 records with proper state machine /// Writer that wraps bytes into TLS 1.3 Application Data records.
/// ///
/// This writer handles partial writes correctly by: /// We chunk outgoing data into records of <= 16384 payload bytes (MAX_TLS_PAYLOAD).
/// - Building complete TLS records before writing /// We do not try to mimic AEAD overhead on the wire; Telegram clients accept it.
/// - Maintaining internal state for partial record writes /// If you want to be more camouflage-accurate later, you could add optional padding
/// - Never splitting a record mid-write to upstream /// to produce records sized closer to MAX_TLS_CHUNK_SIZE.
///
/// # State Machine
///
/// ┌──────────┐ write ┌─────────────────┐
/// │ Idle │ -------------> │ WritingRecord │
/// │ │ <------------- │ │
/// └──────────┘ complete └─────────────────┘
/// │ │
/// │ < errors > │
/// │ │
/// ┌─────────────────────────────────────────────┐
/// │ Poisoned │
/// └─────────────────────────────────────────────┘
///
/// # Record Formation
///
/// Data is chunked into records of at most MAX_TLS_PAYLOAD bytes.
/// Each record has a 5-byte header prepended.
pub struct FakeTlsWriter<W> { pub struct FakeTlsWriter<W> {
/// Upstream writer
upstream: W, upstream: W,
/// Current state
state: TlsWriterState, state: TlsWriterState,
} }
impl<W> FakeTlsWriter<W> { impl<W> FakeTlsWriter<W> {
/// Create new fake TLS writer
pub fn new(upstream: W) -> Self { pub fn new(upstream: W) -> Self {
Self { Self { upstream, state: TlsWriterState::Idle }
upstream,
state: TlsWriterState::Idle,
}
} }
/// Get reference to upstream pub fn get_ref(&self) -> &W { &self.upstream }
pub fn get_ref(&self) -> &W { pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
&self.upstream pub fn into_inner(self) -> W { self.upstream }
}
/// Get mutable reference to upstream pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn get_mut(&mut self) -> &mut W { pub fn state_name(&self) -> &'static str { self.state.state_name() }
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> W {
self.upstream
}
/// Check if stream is in poisoned state
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
/// Get current state name (for debugging)
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
/// Check if there's a pending record to write
pub fn has_pending(&self) -> bool { pub fn has_pending(&self) -> bool {
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty()) matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
} }
/// Transition to poisoned state
fn poison(&mut self, error: io::Error) { fn poison(&mut self, error: io::Error) {
self.state = TlsWriterState::Poisoned { error: Some(error) }; self.state = TlsWriterState::Poisoned { error: Some(error) };
} }
/// Take error from poisoned state
fn take_poison_error(&mut self) -> io::Error { fn take_poison_error(&mut self) -> io::Error {
match &mut self.state { match &mut self.state {
TlsWriterState::Poisoned { error } => { TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
error.take().unwrap_or_else(|| { io::Error::new(ErrorKind::Other, "stream previously poisoned")
io::Error::new(ErrorKind::Other, "stream previously poisoned") }),
})
}
_ => io::Error::new(ErrorKind::Other, "stream not poisoned"), _ => io::Error::new(ErrorKind::Other, "stream not poisoned"),
} }
} }
/// Build a TLS Application Data record
fn build_record(data: &[u8]) -> BytesMut { fn build_record(data: &[u8]) -> BytesMut {
let header = TlsRecordHeader { let header = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION, record_type: TLS_RECORD_APPLICATION,
@@ -689,18 +633,13 @@ impl<W> FakeTlsWriter<W> {
} }
} }
/// Result of flushing pending record
enum FlushResult { enum FlushResult {
/// All data flushed, returns payload size
Complete(usize), Complete(usize),
/// Need to wait for upstream
Pending, Pending,
/// Error occurred
Error(io::Error), Error(io::Error),
} }
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> { impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Try to flush pending record to upstream (standalone logic)
fn poll_flush_record_inner( fn poll_flush_record_inner(
upstream: &mut W, upstream: &mut W,
cx: &mut Context<'_>, cx: &mut Context<'_>,
@@ -710,19 +649,14 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
let data = record.pending(); let data = record.pending();
match Pin::new(&mut *upstream).poll_write(cx, data) { match Pin::new(&mut *upstream).poll_write(cx, data) {
Poll::Pending => return FlushResult::Pending, Poll::Pending => return FlushResult::Pending,
Poll::Ready(Err(e)) => return FlushResult::Error(e), Poll::Ready(Err(e)) => return FlushResult::Error(e),
Poll::Ready(Ok(0)) => { Poll::Ready(Ok(0)) => {
return FlushResult::Error(Error::new( return FlushResult::Error(Error::new(
ErrorKind::WriteZero, ErrorKind::WriteZero,
"upstream returned 0 bytes written" "upstream returned 0 bytes written",
)); ));
} }
Poll::Ready(Ok(n)) => record.advance(n),
Poll::Ready(Ok(n)) => {
record.advance(n);
}
} }
} }
@@ -738,7 +672,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
) -> Poll<Result<usize>> { ) -> Poll<Result<usize>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state // Take ownership of state to avoid borrow conflicts.
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
@@ -751,7 +685,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
TlsWriterState::WritingRecord { mut record, payload_size } => { TlsWriterState::WritingRecord { mut record, payload_size } => {
// Continue flushing existing record // Finish writing previous record before accepting new bytes.
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => { FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size }; this.state = TlsWriterState::WritingRecord { record, payload_size };
@@ -763,7 +697,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
FlushResult::Complete(_) => { FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle; this.state = TlsWriterState::Idle;
// Fall through to handle new write // continue to accept new buf below
} }
} }
} }
@@ -782,19 +716,18 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let chunk_size = buf.len().min(MAX_TLS_PAYLOAD); let chunk_size = buf.len().min(MAX_TLS_PAYLOAD);
let chunk = &buf[..chunk_size]; let chunk = &buf[..chunk_size];
// Build the complete record // Build the complete record (header + payload)
let record_data = Self::build_record(chunk); let record_data = Self::build_record(chunk);
// Try to write directly first
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) { match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
Poll::Ready(Ok(n)) if n == record_data.len() => { Poll::Ready(Ok(n)) if n == record_data.len() => {
// Complete record written
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
// Partial write - buffer the rest // Partial write of the record: store remainder.
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
// record_data length is <= 16389, fits MAX_PENDING_WRITE
let _ = write_buffer.extend(&record_data[n..]); let _ = write_buffer.extend(&record_data[n..]);
this.state = TlsWriterState::WritingRecord { this.state = TlsWriterState::WritingRecord {
@@ -802,7 +735,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
payload_size: chunk_size, payload_size: chunk_size,
}; };
// We've accepted chunk_size bytes from caller // We have accepted chunk_size bytes from caller.
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
@@ -812,7 +745,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
Poll::Pending => { Poll::Pending => {
// Buffer the entire record // Buffer entire record and report success for this chunk.
let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE); let mut write_buffer = WriteBuffer::with_max_size(MAX_PENDING_WRITE);
let _ = write_buffer.extend(&record_data); let _ = write_buffer.extend(&record_data);
@@ -821,10 +754,9 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
payload_size: chunk_size, payload_size: chunk_size,
}; };
// Wake to try again // Wake to retry flushing soon.
cx.waker().wake_by_ref(); cx.waker().wake_by_ref();
// We've accepted chunk_size bytes from caller
Poll::Ready(Ok(chunk_size)) Poll::Ready(Ok(chunk_size))
} }
} }
@@ -833,7 +765,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
@@ -866,48 +797,33 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
} }
} }
// Flush upstream
Pin::new(&mut this.upstream).poll_flush(cx) Pin::new(&mut this.upstream).poll_flush(cx)
} }
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut(); let this = self.get_mut();
// Take ownership of state
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle); let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state { match state {
TlsWriterState::WritingRecord { mut record, payload_size } => { TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
// Try to flush pending (best effort) // Best-effort flush (do not block shutdown forever).
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) { let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
FlushResult::Pending => { this.state = TlsWriterState::Idle;
// Can't complete flush, continue with shutdown anyway
this.state = TlsWriterState::Idle;
}
FlushResult::Error(_) => {
// Ignore errors during shutdown
this.state = TlsWriterState::Idle;
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
}
} }
_ => { _ => {
this.state = TlsWriterState::Idle; this.state = TlsWriterState::Idle;
} }
} }
// Shutdown upstream
Pin::new(&mut this.upstream).poll_shutdown(cx) Pin::new(&mut this.upstream).poll_shutdown(cx)
} }
} }
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> { impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Write all data wrapped in TLS records (async method) /// Write all data wrapped in TLS records.
/// ///
/// This convenience method handles chunking large data into /// Convenience method that chunks into <= 16384 records.
/// multiple TLS records automatically.
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> { pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
let mut written = 0; let mut written = 0;
while written < data.len() { while written < data.len() {

View File

@@ -30,20 +30,13 @@ pub fn configure_tcp_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
} }
// Set buffer sizes // CHANGED: Removed manual buffer size setting (was 256KB).
set_buffer_sizes(&socket, 65536, 65536)?; // Allowing the OS kernel to handle TCP window scaling (Autotuning) is critical
// for mobile clients to avoid bufferbloat and stalled connections during uploads.
Ok(()) Ok(())
} }
/// Set socket buffer sizes
fn set_buffer_sizes(socket: &socket2::SockRef, recv: usize, send: usize) -> Result<()> {
// These may fail on some systems, so we ignore errors
let _ = socket.set_recv_buffer_size(recv);
let _ = socket.set_send_buffer_size(send);
Ok(())
}
/// Configure socket for accepting client connections /// Configure socket for accepting client connections
pub fn configure_client_socket( pub fn configure_client_socket(
stream: &TcpStream, stream: &TcpStream,
@@ -65,6 +58,8 @@ pub fn configure_client_socket(
socket.set_tcp_keepalive(&keepalive)?; socket.set_tcp_keepalive(&keepalive)?;
// Set TCP user timeout (Linux only) // Set TCP user timeout (Linux only)
// NOTE: iOS does not support TCP_USER_TIMEOUT - application-level timeout
// is implemented in relay_bidirectional instead
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;