50 Commits

Author SHA1 Message Date
Alexey
a688bfe22f New Relay on Tokio Copy Bidirectional 2026-02-12 20:20:01 +03:00
Alexey
9bd12f6acb 1.2.0.2 Special DC support: merge pull request #32 from telemt/1.2.0.2
1.2.0.2 Special DC support
2026-02-12 18:46:40 +03:00
Alexey
61581203c4 Semaphore + Async Magics for Defcluster
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:38:05 +03:00
Alexey
84668e671e Default Cluster Drafts
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-12 18:25:41 +03:00
Alexey
5bde202866 Startup logging refactoring: merge pull request #26 from Katze-942/main
Startup logging refactoring
2026-02-12 11:46:22 +03:00
Жора Змейкин
9304d5256a Refactor startup logging
Move all startup output (DC pings, proxy links) from println!() to
      info!() for consistent tracing format. Add reload::Layer so startup
      messages stay visible even in silent mode.
2026-02-12 05:14:23 +03:00
Alexey
364bc6e278 Merge pull request #21 from telemt/1.2.0.0
1.2.0.0
2026-02-11 17:00:46 +03:00
Alexey
e83db704b7 Pull-up 2026-02-11 16:55:18 +03:00
Alexey
acf90043eb Merge pull request #15 from telemt/main-emergency
Update README.md
2026-02-11 00:56:12 +03:00
Alexey
0011e20653 Update README.md 2026-02-11 00:55:27 +03:00
Alexey
41fb307858 Merge pull request #14 from telemt/main-emergency
Update README.md
2026-02-11 00:41:30 +03:00
Alexey
6a78c44d2e Update README.md 2026-02-11 00:41:08 +03:00
Alexey
be9c9858ac Merge pull request #13 from telemt/main-emergency
Main emergency
2026-02-11 00:39:45 +03:00
Alexey
2fa8d85b4c Update README.md 2026-02-11 00:31:45 +03:00
Alexey
310666fd44 Update README.md 2026-02-11 00:31:02 +03:00
Alexey
6cafee153a Fire-and-Forgot™ Draft
- Added fire-and-forget ignition via `--init` CLI command:
  - New `mod cli;` module handling installation logic
  - Extended `parse_cli()` to process `--init` flag (runs synchronously before tokio runtime)
  - Expanded `--help` output with installation options

- `--init` command functionality:
  - Generates random secret if not provided via `--secret`
  - Creates `/etc/telemt/config.toml` from template with user-provided or default parameters (`--port`, `--domain`, `--user`, `--config-dir`)
  - Creates hardened systemd unit `/etc/systemd/system/telemt.service` with security features:
    - `NoNewPrivileges=true`
    - `ProtectSystem=strict`
    - `PrivateTmp=true`
  - Runs `systemctl enable --now telemt.service`
  - Outputs `tg://` proxy links for the running service

- Implementation approach:
  - `--init` handled at the very start of `main()` before any async context
  - Uses blocking operations throughout (file I/O, `std::process::Command` for systemctl)
  - IP detection for tg:// links performed via blocking HTTP request
  - Command exits after installation without entering normal proxy runtime

- New CLI parameters for installation:
  - `--port` - listening port (default: 443)
  - `--domain` - TLS domain (default: auto-detected)
  - `--secret` - custom secret (default: randomly generated)
  - `--user` - systemd service user (default: telemt)
  - `--config-dir` - configuration directory (default: /etc/telemt)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:31:49 +03:00
Alexey
32f60f34db Fix Stats + UpstreamState + EMA Latency Tracking
- Per-DC latency tracking in UpstreamState (array of 5 EMA instances, one per DC):
  - Added `dc_latency: [LatencyEma; 5]` – per‑DC tracking instead of a single global EMA
  - `effective_latency(dc_idx)` – returns DC‑specific latency, falls back to average if unavailable
  - `select_upstream(dc_idx)` – now performs latency‑weighted selection: effective_weight = config_weight × (1000 / latency_ms)
    - Example: two upstreams with equal config weight but latencies of 50ms and 200ms → selection probabilities become 80% / 20%
  - `connect(target, dc_idx)` – extended signature, dc_idx used for upstream selection and per‑DC RTT recording
  - All ping/health‑check operations now record RTT into `dc_latency[dc_zero_index]`
  - `upstream_manager.connect(dc_addr)` changed to `upstream_manager.connect(dc_addr, Some(success.dc_idx))` – DC index now participates in upstream selection and per‑DC RTT logging
  - `client.rs` – passes dc_idx when connecting to Telegram

- Summary: Upstream selection now accounts for per‑DC latency using the formula weight × (1000/ms). With multiple upstreams (e.g., direct + socks5), traffic automatically flows to the faster route for each specific DC. With a single upstream, the data is used for monitoring without affecting routing.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:24:12 +03:00
Alexey
158eae8d2a Antireplay Improvements + DC Ping
- Fix: LruCache::get type ambiguity in stats/mod.rs
  - Changed `self.cache.get(&key.into())` to `self.cache.get(key)` (key is already &[u8], resolved via Box<[u8]>: Borrow<[u8]>)
  - Changed `self.cache.peek(&key)` / `.pop(&key)` to `.peek(key.as_ref())` / `.pop(key.as_ref())` (explicit &[u8] instead of &Box<[u8]>)

- Startup DC ping with RTT display and improved health-check (all DCs, RTT tracking, EMA latency, 30s interval):
  - Implemented `LatencyEma` – exponential moving average (α=0.3) for RTT
  - `connect()` – measures RTT of each real connection and updates EMA
  - `ping_all_dcs()` – pings all 5 DCs via each upstream, returns `Vec<StartupPingResult>` with RTT or error
  - `run_health_checks(prefer_ipv6)` – accepts IPv6 preference parameter, rotates DC between cycles (DC1→DC2→...→DC5→DC1...), interval reduced to 30s from 60s, failed checks now mark upstream as unhealthy after 3 consecutive fails
  - `DcPingResult` / `StartupPingResult` – public structures for display
  - DC Ping at startup: calls `upstream_manager.ping_all_dcs()` before accept loop, outputs table via `println!` (always visible)
  - Health checks with `prefer_ipv6`: `run_health_checks(prefer_ipv6)` receives the parameter
  - Exported `StartupPingResult` and `DcPingResult`

- Summary: Startup DC ping with RTT, rotational health-check with EMA latency tracking, 30-second interval, correct unhealthy marking after 3 fails.

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 20:18:25 +03:00
Alexey
92cedabc81 Zeroize for key + log refactor + fix tests
- Fixed tests that failed to compile due to mismatched generic parameters of HandshakeResult:
  - Changed `HandshakeResult<i32>` to `HandshakeResult<i32, (), ()>`
  - Changed `HandshakeResult::BadClient` to `HandshakeResult::BadClient { reader: (), writer: () }`

- Added Zeroize for all structures holding key material:
  - AesCbc – key and IV are zeroized on drop
  - SecureRandomInner – PRNG output buffer is zeroized on drop; local key copy in constructor is zeroized immediately after being passed to the cipher
  - ObfuscationParams – all four key‑material fields are zeroized on drop
  - HandshakeSuccess – all four key‑material fields are zeroized on drop

- Added protocol‑requirement documentation for legacy hashes (CodeQL suppression) in hash.rs (MD5/SHA‑1)

- Added documentation for zeroize limitations of AesCtr (opaque cipher state) in aes.rs

- Implemented silent‑mode logging and refactored initialization:
  - Added LogLevel enum to config and CLI flags --silent / --log-level
  - Added parse_cli() to handle --silent, --log-level, --help
  - Restructured main.rs initialization order: CLI → config load → determine log level → init tracing
  - Errors before tracing initialization are printed via eprintln!
  - Proxy links (tg://) are printed via println! – always visible regardless of log level
  - Configuration summary and operational messages are logged via info! (suppressed in silent mode)
  - Connection processing errors are lowered to debug! (hidden in silent mode)
  - Warning about default tls_domain moved to main (after tracing init)

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-02-07 19:49:41 +03:00
Alexey
b9428d9780 Antireplay on sliding window + SecureRandom 2026-02-07 18:26:44 +03:00
Alexey
5876f0c4d5 Update rust.yml 2026-02-07 17:58:10 +03:00
Alexey
94750a2749 Update README.md 2026-01-22 03:33:13 +03:00
Alexey
cf4b240913 Update README.md 2026-01-22 03:26:34 +03:00
Alexey
1424fbb1d5 Update README.md 2026-01-22 03:19:50 +03:00
Alexey
97f4c0d3b7 Update README.md 2026-01-22 03:17:37 +03:00
Alexey
806536fab6 Update README.md 2026-01-22 03:14:39 +03:00
Alexey
df8cfe462b Update README.md 2026-01-22 03:13:08 +03:00
Alexey
a5f1521d71 Update README.md 2026-01-22 03:07:38 +03:00
Alexey
8de7b7adc0 Update README.md 2026-01-22 03:03:19 +03:00
Alexey
cde1b15ef0 Update config.toml 2026-01-22 02:45:30 +03:00
Alexey
46e4c06ba6 Update README.md 2026-01-22 01:59:18 +03:00
Alexey
b7673daf0f Update README.md 2026-01-22 01:57:44 +03:00
Alexey
397ed8f193 Update README.md 2026-01-22 01:56:42 +03:00
Alexey
d90b2fd300 Update README.md 2026-01-22 01:55:31 +03:00
Alexey
d62136d9fa Update README.md 2026-01-22 01:53:05 +03:00
Alexey
0f8933b908 Update README.md 2026-01-22 01:48:37 +03:00
Alexey
0ec87974d1 Update README.md 2026-01-22 01:47:43 +03:00
Alexey
c8446c32d1 Update README.md 2026-01-22 01:46:28 +03:00
Alexey
f79a2eb097 Update README.md 2026-01-22 01:26:36 +03:00
Alexey
dea1a3b5de Update README.md 2026-01-22 01:16:46 +03:00
Alexey
97ce235ae4 Update README.md 2026-01-22 01:16:35 +03:00
Alexey
d04757eb9c Update README.md 2026-01-20 11:13:33 +03:00
Alexey
2d7901a978 Update README.md 2026-01-20 11:09:24 +03:00
Alexey
3881ba9bed 1.1.1.0 2026-01-20 02:09:56 +03:00
Alexey
5ac9089ccb Update README.md 2026-01-20 01:39:59 +03:00
Alexey
eb8b991818 Update README.md 2026-01-20 01:32:39 +03:00
Alexey
2ce8fbb2cc 1.1.0.0 2026-01-20 01:20:02 +03:00
Alexey
038f0cd5d1 Update README.md 2026-01-19 23:52:31 +03:00
Alexey
efea3f981d Update README.md 2026-01-19 23:51:43 +03:00
Alexey
42ce9dd671 Update README.md 2026-01-12 22:11:21 +03:00
27 changed files with 5245 additions and 1233 deletions

View File

@@ -14,6 +14,11 @@ jobs:
name: Build name: Build
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4

2742
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,14 @@
[package] [package]
name = "telemt" name = "telemt"
version = "1.0.0" version = "1.2.0"
edition = "2021" edition = "2024"
rust-version = "1.75"
[dependencies] [dependencies]
# C # C
libc = "0.2" libc = "0.2"
# Async runtime # Async runtime
tokio = { version = "1.35", features = ["full", "tracing"] } tokio = { version = "1.42", features = ["full", "tracing"] }
tokio-util = { version = "0.7", features = ["codec"] } tokio-util = { version = "0.7", features = ["codec"] }
# Crypto # Crypto
@@ -20,41 +19,41 @@ sha2 = "0.10"
sha1 = "0.10" sha1 = "0.10"
md-5 = "0.10" md-5 = "0.10"
hmac = "0.12" hmac = "0.12"
crc32fast = "1.3" crc32fast = "1.4"
zeroize = { version = "1.8", features = ["derive"] }
# Network # Network
socket2 = { version = "0.5", features = ["all"] } socket2 = { version = "0.5", features = ["all"] }
rustls = "0.22"
# Serial # Serialization
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
toml = "0.8" toml = "0.8"
# Utils # Utils
bytes = "1.5" bytes = "1.9"
thiserror = "1.0" thiserror = "2.0"
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
parking_lot = "0.12" parking_lot = "0.12"
dashmap = "5.5" dashmap = "5.5"
lru = "0.12" lru = "0.12"
rand = "0.8" rand = "0.9"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
hex = "0.4" hex = "0.4"
base64 = "0.21" base64 = "0.22"
url = "2.5" url = "2.5"
regex = "1.10" regex = "1.11"
once_cell = "1.19"
crossbeam-queue = "0.3" crossbeam-queue = "0.3"
# HTTP # HTTP
reqwest = { version = "0.11", features = ["rustls-tls"], default-features = false } reqwest = { version = "0.12", features = ["rustls-tls"], default-features = false }
[dev-dependencies] [dev-dependencies]
tokio-test = "0.4" tokio-test = "0.4"
criterion = "0.5" criterion = "0.5"
proptest = "1.4" proptest = "1.4"
futures = "0.3"
[[bench]] [[bench]]
name = "crypto_bench" name = "crypto_bench"

214
README.md
View File

@@ -2,6 +2,30 @@
**Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes **Telemt** is a fast, secure, and feature-rich server written in Rust: it fully implements the official Telegram proxy algo and adds many production-ready improvements such as connection pooling, replay protection, detailed statistics, masking from "prying" eyes
## Emergency
**Важное сообщение для пользователей из России**
Мы работаем над проектом с Нового года и сейчас готовим новый релиз - 1.2
В нём имплементируется поддержка Middle Proxy Protocol - основного терминатора для Ad Tag:
работа над ним идёт с 6 ферваля, а уже 10 февраля произошли "громкие события"...
Если у вас есть компетенции в асинхронных сетевых приложениях - мы открыты к предложениям и pull requests
**Important message for users from Russia**
We've been working on the project since December 30 and are currently preparing a new release 1.2
It implements support for the Middle Proxy Protocol the primary point for the Ad Tag:
development on it started on February 6th, and by February 10th, "big activity" in Russia had already "taken place"...
If you have expertise in asynchronous network applications we are open to ideas and pull requests!
# Features
💥 The configuration structure has changed since version 1.1.0.0, change it in your environment!
⚓ Our implementation of **TLS-fronting** is one of the most deeply debugged, focused, advanced and *almost* **"behaviorally consistent to real"**: we are confident we have it right - [see evidence on our validation and traces](#recognizability-for-dpi-and-crawler)
# GOTO # GOTO
- [Features](#features) - [Features](#features)
- [Quick Start Guide](#quick-start-guide) - [Quick Start Guide](#quick-start-guide)
@@ -16,6 +40,7 @@
- [IP](#bind-on-ip) - [IP](#bind-on-ip)
- [SOCKS](#socks45-as-upstream) - [SOCKS](#socks45-as-upstream)
- [FAQ](#faq) - [FAQ](#faq)
- [Recognizability for DPI + crawler](#recognizability-for-dpi-and-crawler)
- [Telegram Calls](#telegram-calls-via-mtproxy) - [Telegram Calls](#telegram-calls-via-mtproxy)
- [DPI](#how-does-dpi-see-mtproxy-tls) - [DPI](#how-does-dpi-see-mtproxy-tls)
- [Whitelist on Network Level](#whitelist-on-ip) - [Whitelist on Network Level](#whitelist-on-ip)
@@ -118,44 +143,100 @@ then Ctrl+X -> Y -> Enter to save
## Configuration ## Configuration
### Minimal Configuration for First Start ### Minimal Configuration for First Start
```toml ```toml
port = 443 # Listening port # === UI ===
show_links = ["tele", "hello"] # Specify users, for whom will be displayed the links # Users to show in the startup log (tg:// links)
show_link = ["hello"]
[users] # === General Settings ===
tele = "00000000000000000000000000000000" # Replace the secret with one generated before [general]
hello = "00000000000000000000000000000000" # Replace the secret with one generated before prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
# ad_tag = "..."
[modes] [general.modes]
classic = false # Plain obfuscated mode classic = false
secure = false # dd-prefix mode secure = false
tls = true # Fake TLS - ee-prefix tls = true
tls_domain = "petrovich.ru" # Domain for ee-secret and masking # === Server Binding ===
mask = true # Enable masking of bad traffic [server]
mask_host = "petrovich.ru" # Optional override for mask destination port = 443
mask_port = 443 # Port for masking listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
prefer_ipv6 = false # Try IPv6 DCs first if true # Listen on multiple interfaces/IPs (overrides listen_addr_*)
fast_mode = true # Use "fast" obfuscation variant [[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
client_keepalive = 600 # Seconds [[server.listeners]]
client_ack_timeout = 300 # Seconds ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "petrovich.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
# username "hello" is used for example
[access]
replay_check_len = 65536
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
# By default, direct connection is used, but you can add SOCKS proxy
# Direct - Default
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# SOCKS5
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1
``` ```
### Advanced ### Advanced
#### Adtag #### Adtag
To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to the end of config.toml and specify it To use channel advertising and usage statistics from Telegram, get Adtag from [@mtproxybot](https://t.me/mtproxybot), add this parameter to section `[General]`
```toml ```toml
ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot ad_tag = "00000000000000000000000000000000" # Replace zeros to your adtag from @mtproxybot
``` ```
#### Listening and Announce IPs #### Listening and Announce IPs
To specify listening address and/or address in links, add to the end of config.toml: To specify listening address and/or address in links, add to section `[[server.listeners]]` of config.toml:
```toml ```toml
[[listeners]] [[server.listeners]]
ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening ip = "0.0.0.0" # 0.0.0.0 = all IPs; your IP = specific listening
announce_ip = "1.2.3.4" # IP in links; comment with # if not used announce_ip = "1.2.3.4" # IP in links; comment with # if not used
``` ```
#### Upstream Manager #### Upstream Manager
To specify upstream, add to the end of config.toml: To specify upstream, add to section `[[upstreams]]` of config.toml:
##### Bind on IP ##### Bind on IP
```toml ```toml
[[upstreams]] [[upstreams]]
@@ -186,6 +267,95 @@ enabled = true
``` ```
## FAQ ## FAQ
### Recognizability for DPI and crawler
Since version 1.1.0.0, we have debugged masking perfectly: for all clients without "presenting" a key,
we transparently direct traffic to the target host!
- We consider this a breakthrough aspect, which has no stable analogues today
- Based on this: if `telemt` configured correctly, **TLS mode is completely identical to real-life handshake + communication** with a specified host
- Here is our evidence:
- 212.220.88.77 - "dummy" host, running `telemt`
- `petrovich.ru` - `tls` + `masking` host, in HEX: `706574726f766963682e7275`
- **No MITM + No Fake Certificates/Crypto** = pure transparent *TCP Splice* to "best" upstream: MTProxy or tls/mask-host:
- DPI see legitimate HTTPS to `tls_host`, including *valid chain-of-trust* and entropy
- Crawlers completely satisfied receiving responses from `mask_host`
#### Client WITH secret-key accesses the MTProxy resource:
<img width="360" height="439" alt="telemt" src="https://github.com/user-attachments/assets/39352afb-4a11-4ecc-9d91-9e8cfb20607d" />
#### Client WITHOUT secret-key gets transparent access to the specified resource:
- with trusted certificate
- with original handshake
- with full request-response way
- with low-latency overhead
```bash
root@debian:~/telemt# curl -v -I --resolve petrovich.ru:443:212.220.88.77 https://petrovich.ru/
* Added petrovich.ru:443:212.220.88.77 to DNS cache
* Hostname petrovich.ru was found in DNS cache
* Trying 212.220.88.77:443...
* Connected to petrovich.ru (212.220.88.77) port 443 (#0)
* ALPN: offers h2,http/1.1
* TLSv1.3 (OUT), TLS handshake, Client hello (1):
* CAfile: /etc/ssl/certs/ca-certificates.crt
* CApath: /etc/ssl/certs
* TLSv1.3 (IN), TLS handshake, Server hello (2):
* TLSv1.3 (IN), TLS handshake, Encrypted Extensions (8):
* TLSv1.3 (IN), TLS handshake, Certificate (11):
* TLSv1.3 (IN), TLS handshake, CERT verify (15):
* TLSv1.3 (IN), TLS handshake, Finished (20):
* TLSv1.3 (OUT), TLS change cipher, Change cipher spec (1):
* TLSv1.3 (OUT), TLS handshake, Finished (20):
* SSL connection using TLSv1.3 / TLS_AES_256_GCM_SHA384
* ALPN: server did not agree on a protocol. Uses default.
* Server certificate:
* subject: C=RU; ST=Saint Petersburg; L=Saint Petersburg; O=STD Petrovich; CN=*.petrovich.ru
* start date: Jan 28 11:21:01 2025 GMT
* expire date: Mar 1 11:21:00 2026 GMT
* subjectAltName: host "petrovich.ru" matched cert's "petrovich.ru"
* issuer: C=BE; O=GlobalSign nv-sa; CN=GlobalSign RSA OV SSL CA 2018
* SSL certificate verify ok.
* using HTTP/1.x
> HEAD / HTTP/1.1
> Host: petrovich.ru
> User-Agent: curl/7.88.1
> Accept: */*
>
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* TLSv1.3 (IN), TLS handshake, Newsession Ticket (4):
* old SSL session ID is stale, removing
< HTTP/1.1 200 OK
HTTP/1.1 200 OK
< Server: Variti/0.9.3a
Server: Variti/0.9.3a
< Date: Thu, 01 Jan 2026 00:0000 GMT
Date: Thu, 01 Jan 2026 00:0000 GMT
< Access-Control-Allow-Origin: *
Access-Control-Allow-Origin: *
< Content-Type: text/html
Content-Type: text/html
< Cache-Control: no-store
Cache-Control: no-store
< Expires: Thu, 01 Jan 2026 00:0000 GMT
Expires: Thu, 01 Jan 2026 00:0000 GMT
< Pragma: no-cache
Pragma: no-cache
< Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
Set-Cookie: ipp_uid=XXXXX/XXXXX/XXXXX==; Expires=Tue, 31 Dec 2040 23:59:59 GMT; Domain=.petrovich.ru; Path=/
< Content-Type: text/html
Content-Type: text/html
< Content-Length: 31253
Content-Length: 31253
< Connection: keep-alive
Connection: keep-alive
< Keep-Alive: timeout=60
Keep-Alive: timeout=60
<
* Connection #0 to host petrovich.ru left intact
```
- We challenged ourselves, we kept trying and we didn't only *beat the air*: now, we have something to show you
- Do not just take our word for it? - This is great and we respect that: you can build your own `telemt` or download a build and check it right now
### Telegram Calls via MTProxy ### Telegram Calls via MTProxy
- Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated - Telegram architecture **does NOT allow calls via MTProxy**, but only via SOCKS5, which cannot be obfuscated
### How does DPI see MTProxy TLS? ### How does DPI see MTProxy TLS?
@@ -233,7 +403,7 @@ telemt config.toml
## Issues ## Issues
- ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management - ✅ [SOCKS5 as Upstream](https://github.com/telemt/telemt/issues/1) -> added Upstream Management
- [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2) - [iOS - Media Upload Hanging-in-Loop](https://github.com/telemt/telemt/issues/2)
## Roadmap ## Roadmap
- Public IP in links - Public IP in links

View File

@@ -1,13 +1,79 @@
port = 443 # === UI ===
# Users to show in the startup log (tg:// links)
show_link = ["hello"]
[users] # === General Settings ===
user1 = "00000000000000000000000000000000" [general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = true
ad_tag = "00000000000000000000000000000000"
[modes] # Log level: debug | verbose | normal | silent
classic = true # Can be overridden with --silent or --log-level CLI flags
secure = true # RUST_LOG env var takes absolute priority over all of these
log_level = "normal"
[general.modes]
classic = false
secure = false
tls = true tls = true
tls_domain = "www.github.com" # === Server Binding ===
fast_mode = true [server]
prefer_ipv6 = false port = 443
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
# metrics_port = 9090
# metrics_whitelist = ["127.0.0.1", "::1"]
# Listen on multiple interfaces/IPs (overrides listen_addr_*)
[[server.listeners]]
ip = "0.0.0.0"
# announce_ip = "1.2.3.4" # Optional: Public IP for tg:// links
[[server.listeners]]
ip = "::"
# === Timeouts (in seconds) ===
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
# === Anti-Censorship & Masking ===
[censorship]
tls_domain = "google.ru"
mask = true
mask_port = 443
# mask_host = "petrovich.ru" # Defaults to tls_domain if not set
fake_cert_len = 2048
# === Access Control & Users ===
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
# format: "username" = "32_hex_chars_secret"
hello = "00000000000000000000000000000000"
# [access.user_max_tcp_conns]
# hello = 50
# [access.user_data_quota]
# hello = 1073741824 # 1 GB
# === Upstreams & Routing ===
[[upstreams]]
type = "direct"
enabled = true
weight = 10
# [[upstreams]]
# type = "socks5"
# address = "127.0.0.1:9050"
# enabled = false
# weight = 1

300
src/cli.rs Normal file
View File

@@ -0,0 +1,300 @@
//! CLI commands: --init (fire-and-forget setup)
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use rand::Rng;
/// Options for the init command
pub struct InitOptions {
pub port: u16,
pub domain: String,
pub secret: Option<String>,
pub username: String,
pub config_dir: PathBuf,
pub no_start: bool,
}
impl Default for InitOptions {
fn default() -> Self {
Self {
port: 443,
domain: "www.google.com".to_string(),
secret: None,
username: "user".to_string(),
config_dir: PathBuf::from("/etc/telemt"),
no_start: false,
}
}
}
/// Parse --init subcommand options from CLI args.
///
/// Returns `Some(InitOptions)` if `--init` was found, `None` otherwise.
pub fn parse_init_args(args: &[String]) -> Option<InitOptions> {
if !args.iter().any(|a| a == "--init") {
return None;
}
let mut opts = InitOptions::default();
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--port" => {
i += 1;
if i < args.len() {
opts.port = args[i].parse().unwrap_or(443);
}
}
"--domain" => {
i += 1;
if i < args.len() {
opts.domain = args[i].clone();
}
}
"--secret" => {
i += 1;
if i < args.len() {
opts.secret = Some(args[i].clone());
}
}
"--user" => {
i += 1;
if i < args.len() {
opts.username = args[i].clone();
}
}
"--config-dir" => {
i += 1;
if i < args.len() {
opts.config_dir = PathBuf::from(&args[i]);
}
}
"--no-start" => {
opts.no_start = true;
}
_ => {}
}
i += 1;
}
Some(opts)
}
/// Run the fire-and-forget setup.
pub fn run_init(opts: InitOptions) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("[telemt] Fire-and-forget setup");
eprintln!();
// 1. Generate or validate secret
let secret = match opts.secret {
Some(s) => {
if s.len() != 32 || !s.chars().all(|c| c.is_ascii_hexdigit()) {
eprintln!("[error] Secret must be exactly 32 hex characters");
std::process::exit(1);
}
s
}
None => generate_secret(),
};
eprintln!("[+] Secret: {}", secret);
eprintln!("[+] User: {}", opts.username);
eprintln!("[+] Port: {}", opts.port);
eprintln!("[+] Domain: {}", opts.domain);
// 2. Create config directory
fs::create_dir_all(&opts.config_dir)?;
let config_path = opts.config_dir.join("config.toml");
// 3. Write config
let config_content = generate_config(&opts.username, &secret, opts.port, &opts.domain);
fs::write(&config_path, &config_content)?;
eprintln!("[+] Config written to {}", config_path.display());
// 4. Write systemd unit
let exe_path = std::env::current_exe()
.unwrap_or_else(|_| PathBuf::from("/usr/local/bin/telemt"));
let unit_path = Path::new("/etc/systemd/system/telemt.service");
let unit_content = generate_systemd_unit(&exe_path, &config_path);
match fs::write(unit_path, &unit_content) {
Ok(()) => {
eprintln!("[+] Systemd unit written to {}", unit_path.display());
}
Err(e) => {
eprintln!("[!] Cannot write systemd unit (run as root?): {}", e);
eprintln!("[!] Manual unit file content:");
eprintln!("{}", unit_content);
// Still print links and config
print_links(&opts.username, &secret, opts.port, &opts.domain);
return Ok(());
}
}
// 5. Reload systemd
run_cmd("systemctl", &["daemon-reload"]);
// 6. Enable service
run_cmd("systemctl", &["enable", "telemt.service"]);
eprintln!("[+] Service enabled");
// 7. Start service (unless --no-start)
if !opts.no_start {
run_cmd("systemctl", &["start", "telemt.service"]);
eprintln!("[+] Service started");
// Brief delay then check status
std::thread::sleep(std::time::Duration::from_secs(1));
let status = Command::new("systemctl")
.args(["is-active", "telemt.service"])
.output();
match status {
Ok(out) if out.status.success() => {
eprintln!("[+] Service is running");
}
_ => {
eprintln!("[!] Service may not have started correctly");
eprintln!("[!] Check: journalctl -u telemt.service -n 20");
}
}
} else {
eprintln!("[+] Service not started (--no-start)");
eprintln!("[+] Start manually: systemctl start telemt.service");
}
eprintln!();
// 8. Print links
print_links(&opts.username, &secret, opts.port, &opts.domain);
Ok(())
}
fn generate_secret() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..16).map(|_| rng.random::<u8>()).collect();
hex::encode(bytes)
}
fn generate_config(username: &str, secret: &str, port: u16, domain: &str) -> String {
format!(
r#"# Telemt MTProxy — auto-generated config
# Re-run `telemt --init` to regenerate
show_link = ["{username}"]
[general]
prefer_ipv6 = false
fast_mode = true
use_middle_proxy = false
log_level = "normal"
[general.modes]
classic = false
secure = false
tls = true
[server]
port = {port}
listen_addr_ipv4 = "0.0.0.0"
listen_addr_ipv6 = "::"
[[server.listeners]]
ip = "0.0.0.0"
[[server.listeners]]
ip = "::"
[timeouts]
client_handshake = 15
tg_connect = 10
client_keepalive = 60
client_ack = 300
[censorship]
tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048
[access]
replay_check_len = 65536
replay_window_secs = 1800
ignore_time_skew = false
[access.users]
{username} = "{secret}"
[[upstreams]]
type = "direct"
enabled = true
weight = 10
"#,
username = username,
secret = secret,
port = port,
domain = domain,
)
}
fn generate_systemd_unit(exe_path: &Path, config_path: &Path) -> String {
format!(
r#"[Unit]
Description=Telemt MTProxy
Documentation=https://github.com/nicepkg/telemt
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
ExecStart={exe} {config}
Restart=always
RestartSec=5
LimitNOFILE=65535
# Security hardening
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=/etc/telemt
PrivateTmp=true
[Install]
WantedBy=multi-user.target
"#,
exe = exe_path.display(),
config = config_path.display(),
)
}
fn run_cmd(cmd: &str, args: &[&str]) {
match Command::new(cmd).args(args).output() {
Ok(output) => {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
eprintln!("[!] {} {} failed: {}", cmd, args.join(" "), stderr.trim());
}
}
Err(e) => {
eprintln!("[!] Failed to run {} {}: {}", cmd, args.join(" "), e);
}
}
}
fn print_links(username: &str, secret: &str, port: u16, domain: &str) {
let domain_hex = hex::encode(domain);
println!("=== Proxy Links ===");
println!("[{}]", username);
println!(" EE-TLS: tg://proxy?server=YOUR_SERVER_IP&port={}&secret=ee{}{}",
port, secret, domain_hex);
println!();
println!("Replace YOUR_SERVER_IP with your server's public IP.");
println!("The proxy will auto-detect and display the correct link on startup.");
println!("Check: journalctl -u telemt.service | head -30");
println!("===================");
}

View File

@@ -1,12 +1,89 @@
//! Configuration //! Configuration
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::{IpAddr, SocketAddr};
use std::path::Path; use std::path::Path;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
// ============= Helper Defaults =============
fn default_true() -> bool { true }
fn default_port() -> u16 { 443 }
fn default_tls_domain() -> String { "www.google.com".to_string() }
fn default_mask_port() -> u16 { 443 }
fn default_replay_check_len() -> usize { 65536 }
fn default_replay_window_secs() -> u64 { 1800 }
fn default_handshake_timeout() -> u64 { 15 }
fn default_connect_timeout() -> u64 { 10 }
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_weight() -> u16 { 1 }
fn default_metrics_whitelist() -> Vec<IpAddr> {
vec![
"127.0.0.1".parse().unwrap(),
"::1".parse().unwrap(),
]
}
// ============= Log Level =============
/// Logging verbosity level
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
/// All messages including trace (trace + debug + info + warn + error)
Debug,
/// Detailed operational logs (debug + info + warn + error)
Verbose,
/// Standard operational logs (info + warn + error)
#[default]
Normal,
/// Minimal output: only warnings and errors (warn + error).
/// Startup messages (config, DC connectivity, proxy links) are always shown
/// via info! before the filter is applied.
Silent,
}
impl LogLevel {
/// Convert to tracing EnvFilter directive string
pub fn to_filter_str(&self) -> &'static str {
match self {
LogLevel::Debug => "trace",
LogLevel::Verbose => "debug",
LogLevel::Normal => "info",
LogLevel::Silent => "warn",
}
}
/// Parse from a loose string (CLI argument)
pub fn from_str_loose(s: &str) -> Self {
match s.to_lowercase().as_str() {
"debug" | "trace" => LogLevel::Debug,
"verbose" => LogLevel::Verbose,
"normal" | "info" => LogLevel::Normal,
"silent" | "quiet" | "error" | "warn" => LogLevel::Silent,
_ => LogLevel::Normal,
}
}
}
impl std::fmt::Display for LogLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LogLevel::Debug => write!(f, "debug"),
LogLevel::Verbose => write!(f, "verbose"),
LogLevel::Normal => write!(f, "normal"),
LogLevel::Silent => write!(f, "silent"),
}
}
}
// ============= Sub-Configs =============
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyModes { pub struct ProxyModes {
#[serde(default)] #[serde(default)]
@@ -17,26 +94,193 @@ pub struct ProxyModes {
pub tls: bool, pub tls: bool,
} }
fn default_true() -> bool { true }
fn default_weight() -> u16 { 1 }
impl Default for ProxyModes { impl Default for ProxyModes {
fn default() -> Self { fn default() -> Self {
Self { classic: true, secure: true, tls: true } Self { classic: true, secure: true, tls: true }
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralConfig {
#[serde(default)]
pub modes: ProxyModes,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub ad_tag: Option<String>,
#[serde(default)]
pub log_level: LogLevel,
}
impl Default for GeneralConfig {
fn default() -> Self {
Self {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
ad_tag: None,
log_level: LogLevel::Normal,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
listeners: Vec::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeoutsConfig {
#[serde(default = "default_handshake_timeout")]
pub client_handshake: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack: u64,
}
impl Default for TimeoutsConfig {
fn default() -> Self {
Self {
client_handshake: default_handshake_timeout(),
tg_connect: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack: default_ack_timeout(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)]
pub mask_host: Option<String>,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
}
impl Default for AntiCensorshipConfig {
fn default() -> Self {
Self {
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
fake_cert_len: default_fake_cert_len(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessConfig {
#[serde(default)]
pub users: HashMap<String, String>,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default = "default_replay_window_secs")]
pub replay_window_secs: u64,
#[serde(default)]
pub ignore_time_skew: bool,
}
impl Default for AccessConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
users,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
replay_window_secs: default_replay_window_secs(),
ignore_time_skew: false,
}
}
}
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")] #[serde(tag = "type", rename_all = "lowercase")]
pub enum UpstreamType { pub enum UpstreamType {
Direct { Direct {
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface interface: Option<String>,
}, },
Socks4 { Socks4 {
address: String, // IP:Port of SOCKS server address: String,
#[serde(default)] #[serde(default)]
interface: Option<String>, // Bind to specific IP/Interface for connection to SOCKS interface: Option<String>,
#[serde(default)] #[serde(default)]
user_id: Option<String>, user_id: Option<String>,
}, },
@@ -65,158 +309,48 @@ pub struct UpstreamConfig {
pub struct ListenerConfig { pub struct ListenerConfig {
pub ip: IpAddr, pub ip: IpAddr,
#[serde(default)] #[serde(default)]
pub announce_ip: Option<IpAddr>, // IP to show in tg:// links pub announce_ip: Option<IpAddr>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] // ============= Main Config =============
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProxyConfig { pub struct ProxyConfig {
#[serde(default = "default_port")] #[serde(default)]
pub port: u16, pub general: GeneralConfig,
#[serde(default)] #[serde(default)]
pub users: HashMap<String, String>, pub server: ServerConfig,
#[serde(default)] #[serde(default)]
pub ad_tag: Option<String>, pub timeouts: TimeoutsConfig,
#[serde(default)] #[serde(default)]
pub modes: ProxyModes, pub censorship: AntiCensorshipConfig,
#[serde(default = "default_tls_domain")]
pub tls_domain: String,
#[serde(default = "default_true")]
pub mask: bool,
#[serde(default)] #[serde(default)]
pub mask_host: Option<String>, pub access: AccessConfig,
#[serde(default = "default_mask_port")]
pub mask_port: u16,
#[serde(default)]
pub prefer_ipv6: bool,
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
pub user_max_tcp_conns: HashMap<String, usize>,
#[serde(default)]
pub user_expirations: HashMap<String, DateTime<Utc>>,
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
#[serde(default = "default_replay_check_len")]
pub replay_check_len: usize,
#[serde(default)]
pub ignore_time_skew: bool,
#[serde(default = "default_handshake_timeout")]
pub client_handshake_timeout: u64,
#[serde(default = "default_connect_timeout")]
pub tg_connect_timeout: u64,
#[serde(default = "default_keepalive")]
pub client_keepalive: u64,
#[serde(default = "default_ack_timeout")]
pub client_ack_timeout: u64,
#[serde(default = "default_listen_addr")]
pub listen_addr_ipv4: String,
#[serde(default)]
pub listen_addr_ipv6: Option<String>,
#[serde(default)]
pub listen_unix_sock: Option<String>,
#[serde(default)]
pub metrics_port: Option<u16>,
#[serde(default = "default_metrics_whitelist")]
pub metrics_whitelist: Vec<IpAddr>,
#[serde(default = "default_fake_cert_len")]
pub fake_cert_len: usize,
// New fields
#[serde(default)] #[serde(default)]
pub upstreams: Vec<UpstreamConfig>, pub upstreams: Vec<UpstreamConfig>,
#[serde(default)]
pub listeners: Vec<ListenerConfig>,
#[serde(default)] #[serde(default)]
pub show_link: Vec<String>, pub show_link: Vec<String>,
}
fn default_port() -> u16 { 443 } /// DC address overrides for non-standard DCs (CDN, media, test, etc.)
fn default_tls_domain() -> String { "www.google.com".to_string() } /// Keys are DC indices as strings, values are "ip:port" addresses.
fn default_mask_port() -> u16 { 443 } /// Matches the C implementation's `proxy_for <dc_id> <ip>:<port>` config directive.
fn default_replay_check_len() -> usize { 65536 } /// Example in config.toml:
// CHANGED: Increased handshake timeout for bad mobile networks /// [dc_overrides]
fn default_handshake_timeout() -> u64 { 15 } /// "203" = "149.154.175.100:443"
fn default_connect_timeout() -> u64 { 10 } #[serde(default)]
// CHANGED: Reduced keepalive from 600s to 60s. pub dc_overrides: HashMap<String, String>,
// Mobile NATs often drop idle connections after 60-120s.
fn default_keepalive() -> u64 { 60 }
fn default_ack_timeout() -> u64 { 300 }
fn default_listen_addr() -> String { "0.0.0.0".to_string() }
fn default_fake_cert_len() -> usize { 2048 }
fn default_metrics_whitelist() -> Vec<IpAddr> { /// Default DC index (1-5) for unmapped non-standard DCs.
vec![ /// Matches the C implementation's `default <dc_id>` config directive.
"127.0.0.1".parse().unwrap(), /// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
"::1".parse().unwrap(), #[serde(default)]
] pub default_dc: Option<u8>,
}
impl Default for ProxyConfig {
fn default() -> Self {
let mut users = HashMap::new();
users.insert("default".to_string(), "00000000000000000000000000000000".to_string());
Self {
port: default_port(),
users,
ad_tag: None,
modes: ProxyModes::default(),
tls_domain: default_tls_domain(),
mask: true,
mask_host: None,
mask_port: default_mask_port(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: false,
user_max_tcp_conns: HashMap::new(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
replay_check_len: default_replay_check_len(),
ignore_time_skew: false,
client_handshake_timeout: default_handshake_timeout(),
tg_connect_timeout: default_connect_timeout(),
client_keepalive: default_keepalive(),
client_ack_timeout: default_ack_timeout(),
listen_addr_ipv4: default_listen_addr(),
listen_addr_ipv6: Some("::".to_string()),
listen_unix_sock: None,
metrics_port: None,
metrics_whitelist: default_metrics_whitelist(),
fake_cert_len: default_fake_cert_len(),
upstreams: Vec::new(),
listeners: Vec::new(),
show_link: Vec::new(),
}
}
} }
impl ProxyConfig { impl ProxyConfig {
@@ -228,7 +362,7 @@ impl ProxyConfig {
.map_err(|e| ProxyError::Config(e.to_string()))?; .map_err(|e| ProxyError::Config(e.to_string()))?;
// Validate secrets // Validate secrets
for (user, secret) in &config.users { for (user, secret) in &config.access.users {
if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 { if !secret.chars().all(|c| c.is_ascii_hexdigit()) || secret.len() != 32 {
return Err(ProxyError::InvalidSecret { return Err(ProxyError::InvalidSecret {
user: user.clone(), user: user.clone(),
@@ -237,26 +371,31 @@ impl ProxyConfig {
} }
} }
// Default mask_host // Validate tls_domain
if config.mask_host.is_none() { if config.censorship.tls_domain.is_empty() {
config.mask_host = Some(config.tls_domain.clone()); return Err(ProxyError::Config("tls_domain cannot be empty".to_string()));
}
// Default mask_host to tls_domain if not set
if config.censorship.mask_host.is_none() {
config.censorship.mask_host = Some(config.censorship.tls_domain.clone());
} }
// Random fake_cert_len // Random fake_cert_len
use rand::Rng; use rand::Rng;
config.fake_cert_len = rand::thread_rng().gen_range(1024..4096); config.censorship.fake_cert_len = rand::rng().gen_range(1024..4096);
// Migration: Populate listeners if empty // Migration: Populate listeners if empty
if config.listeners.is_empty() { if config.server.listeners.is_empty() {
if let Ok(ipv4) = config.listen_addr_ipv4.parse::<IpAddr>() { if let Ok(ipv4) = config.server.listen_addr_ipv4.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv4, ip: ipv4,
announce_ip: None, announce_ip: None,
}); });
} }
if let Some(ipv6_str) = &config.listen_addr_ipv6 { if let Some(ipv6_str) = &config.server.listen_addr_ipv6 {
if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() { if let Ok(ipv6) = ipv6_str.parse::<IpAddr>() {
config.listeners.push(ListenerConfig { config.server.listeners.push(ListenerConfig {
ip: ipv6, ip: ipv6,
announce_ip: None, announce_ip: None,
}); });
@@ -277,14 +416,20 @@ impl ProxyConfig {
} }
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.users.is_empty() { if self.access.users.is_empty() {
return Err(ProxyError::Config("No users configured".to_string())); return Err(ProxyError::Config("No users configured".to_string()));
} }
if !self.modes.classic && !self.modes.secure && !self.modes.tls { if !self.general.modes.classic && !self.general.modes.secure && !self.general.modes.tls {
return Err(ProxyError::Config("No modes enabled".to_string())); return Err(ProxyError::Config("No modes enabled".to_string()));
} }
if self.censorship.tls_domain.contains(' ') || self.censorship.tls_domain.contains('/') {
return Err(ProxyError::Config(
format!("Invalid tls_domain: '{}'. Must be a valid domain name", self.censorship.tls_domain)
));
}
Ok(()) Ok(())
} }
} }

View File

@@ -1,9 +1,19 @@
//! AES encryption implementations //! AES encryption implementations
//! //!
//! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption. //! Provides AES-256-CTR and AES-256-CBC modes for MTProto encryption.
//!
//! ## Zeroize policy
//!
//! - `AesCbc` stores raw key/IV bytes and zeroizes them on drop.
//! - `AesCtr` wraps an opaque `Aes256Ctr` cipher from the `ctr` crate.
//! The expanded key schedule lives inside that type and cannot be
//! zeroized from outside. Callers that hold raw key material (e.g.
//! `HandshakeSuccess`, `ObfuscationParams`) are responsible for
//! zeroizing their own copies.
use aes::Aes256; use aes::Aes256;
use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}}; use ctr::{Ctr128BE, cipher::{KeyIvInit, StreamCipher}};
use zeroize::Zeroize;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
type Aes256Ctr = Ctr128BE<Aes256>; type Aes256Ctr = Ctr128BE<Aes256>;
@@ -12,7 +22,12 @@ type Aes256Ctr = Ctr128BE<Aes256>;
/// AES-256-CTR encryptor/decryptor /// AES-256-CTR encryptor/decryptor
/// ///
/// CTR mode is symmetric - encryption and decryption are the same operation. /// CTR mode is symmetric encryption and decryption are the same operation.
///
/// **Zeroize note:** The inner `Aes256Ctr` cipher state (expanded key schedule
/// + counter) is opaque and cannot be zeroized. If you need to protect key
/// material, zeroize the `[u8; 32]` key and `u128` IV at the call site
/// before dropping them.
pub struct AesCtr { pub struct AesCtr {
cipher: Aes256Ctr, cipher: Aes256Ctr,
} }
@@ -62,14 +77,23 @@ impl AesCtr {
/// AES-256-CBC cipher with proper chaining /// AES-256-CBC cipher with proper chaining
/// ///
/// Unlike CTR mode, CBC is NOT symmetric - encryption and decryption /// Unlike CTR mode, CBC is NOT symmetric encryption and decryption
/// are different operations. This implementation handles CBC chaining /// are different operations. This implementation handles CBC chaining
/// correctly across multiple blocks. /// correctly across multiple blocks.
///
/// Key and IV are zeroized on drop.
pub struct AesCbc { pub struct AesCbc {
key: [u8; 32], key: [u8; 32],
iv: [u8; 16], iv: [u8; 16],
} }
impl Drop for AesCbc {
fn drop(&mut self) {
self.key.zeroize();
self.iv.zeroize();
}
}
impl AesCbc { impl AesCbc {
/// AES block size /// AES block size
const BLOCK_SIZE: usize = 16; const BLOCK_SIZE: usize = 16;
@@ -141,17 +165,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let plaintext: [u8; 16] = chunk.try_into().unwrap(); let plaintext: [u8; 16] = chunk.try_into().unwrap();
// XOR plaintext with previous ciphertext (or IV for first block)
let xored = Self::xor_blocks(&plaintext, &prev_ciphertext); let xored = Self::xor_blocks(&plaintext, &prev_ciphertext);
// Encrypt the XORed block
let ciphertext = self.encrypt_block(&xored, &key_schedule); let ciphertext = self.encrypt_block(&xored, &key_schedule);
// Save for next iteration
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&ciphertext); result.extend_from_slice(&ciphertext);
} }
@@ -180,17 +196,9 @@ impl AesCbc {
for chunk in data.chunks(Self::BLOCK_SIZE) { for chunk in data.chunks(Self::BLOCK_SIZE) {
let ciphertext: [u8; 16] = chunk.try_into().unwrap(); let ciphertext: [u8; 16] = chunk.try_into().unwrap();
// Decrypt the block
let decrypted = self.decrypt_block(&ciphertext, &key_schedule); let decrypted = self.decrypt_block(&ciphertext, &key_schedule);
// XOR with previous ciphertext (or IV for first block)
let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext); let plaintext = Self::xor_blocks(&decrypted, &prev_ciphertext);
// Save current ciphertext for next iteration
prev_ciphertext = ciphertext; prev_ciphertext = ciphertext;
// Append to result
result.extend_from_slice(&plaintext); result.extend_from_slice(&plaintext);
} }
@@ -217,16 +225,13 @@ impl AesCbc {
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
// Encrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.encrypt_block(block_array, &key_schedule); *block_array = self.encrypt_block(block_array, &key_schedule);
// Save for next iteration
prev_ciphertext = *block_array; prev_ciphertext = *block_array;
} }
@@ -248,26 +253,20 @@ impl AesCbc {
use aes::cipher::KeyInit; use aes::cipher::KeyInit;
let key_schedule = aes::Aes256::new((&self.key).into()); let key_schedule = aes::Aes256::new((&self.key).into());
// For in-place decryption, we need to save ciphertext blocks
// before we overwrite them
let mut prev_ciphertext = self.iv; let mut prev_ciphertext = self.iv;
for i in (0..data.len()).step_by(Self::BLOCK_SIZE) { for i in (0..data.len()).step_by(Self::BLOCK_SIZE) {
let block = &mut data[i..i + Self::BLOCK_SIZE]; let block = &mut data[i..i + Self::BLOCK_SIZE];
// Save current ciphertext before modifying
let current_ciphertext: [u8; 16] = block.try_into().unwrap(); let current_ciphertext: [u8; 16] = block.try_into().unwrap();
// Decrypt in-place
let block_array: &mut [u8; 16] = block.try_into().unwrap(); let block_array: &mut [u8; 16] = block.try_into().unwrap();
*block_array = self.decrypt_block(block_array, &key_schedule); *block_array = self.decrypt_block(block_array, &key_schedule);
// XOR with previous ciphertext
for j in 0..Self::BLOCK_SIZE { for j in 0..Self::BLOCK_SIZE {
block[j] ^= prev_ciphertext[j]; block[j] ^= prev_ciphertext[j];
} }
// Save for next iteration
prev_ciphertext = current_ciphertext; prev_ciphertext = current_ciphertext;
} }
@@ -347,10 +346,8 @@ mod tests {
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
// Encrypted should be different
assert_ne!(&data[..], original); assert_ne!(&data[..], original);
// Decrypt with fresh cipher
let mut cipher = AesCtr::new(&key, iv); let mut cipher = AesCtr::new(&key, iv);
cipher.apply(&mut data); cipher.apply(&mut data);
@@ -364,7 +361,7 @@ mod tests {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let original = [0u8; 32]; // 2 blocks let original = [0u8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let encrypted = cipher.encrypt(&original).unwrap(); let encrypted = cipher.encrypt(&original).unwrap();
@@ -375,31 +372,25 @@ mod tests {
#[test] #[test]
fn test_aes_cbc_chaining_works() { fn test_aes_cbc_chaining_works() {
// This is the key test - verify CBC chaining is correct
let key = [0x42u8; 32]; let key = [0x42u8; 32];
let iv = [0x00u8; 16]; let iv = [0x00u8; 16];
// Two IDENTICAL plaintext blocks
let plaintext = [0xAAu8; 32]; let plaintext = [0xAAu8; 32];
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// With proper CBC, identical plaintext blocks produce DIFFERENT ciphertext
let block1 = &ciphertext[0..16]; let block1 = &ciphertext[0..16];
let block2 = &ciphertext[16..32]; let block2 = &ciphertext[16..32];
assert_ne!( assert_ne!(
block1, block2, block1, block2,
"CBC chaining broken: identical plaintext blocks produced identical ciphertext. \ "CBC chaining broken: identical plaintext blocks produced identical ciphertext"
This indicates ECB mode, not CBC!"
); );
} }
#[test] #[test]
fn test_aes_cbc_known_vector() { fn test_aes_cbc_known_vector() {
// Test with known NIST test vector
// AES-256-CBC with zero key and zero IV
let key = [0u8; 32]; let key = [0u8; 32];
let iv = [0u8; 16]; let iv = [0u8; 16];
let plaintext = [0u8; 16]; let plaintext = [0u8; 16];
@@ -407,11 +398,9 @@ mod tests {
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext = cipher.encrypt(&plaintext).unwrap(); let ciphertext = cipher.encrypt(&plaintext).unwrap();
// Decrypt and verify roundtrip
let decrypted = cipher.decrypt(&ciphertext).unwrap(); let decrypted = cipher.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice()); assert_eq!(plaintext.as_slice(), decrypted.as_slice());
// Ciphertext should not be all zeros
assert_ne!(ciphertext.as_slice(), plaintext.as_slice()); assert_ne!(ciphertext.as_slice(), plaintext.as_slice());
} }
@@ -420,7 +409,6 @@ mod tests {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
// 5 blocks = 80 bytes
let plaintext: Vec<u8> = (0..80).collect(); let plaintext: Vec<u8> = (0..80).collect();
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -435,7 +423,7 @@ mod tests {
let key = [0x12u8; 32]; let key = [0x12u8; 32];
let iv = [0x34u8; 16]; let iv = [0x34u8; 16];
let original = [0x56u8; 48]; // 3 blocks let original = [0x56u8; 48];
let mut buffer = original; let mut buffer = original;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
@@ -462,41 +450,33 @@ mod tests {
fn test_aes_cbc_unaligned_error() { fn test_aes_cbc_unaligned_error() {
let cipher = AesCbc::new([0u8; 32], [0u8; 16]); let cipher = AesCbc::new([0u8; 32], [0u8; 16]);
// 15 bytes - not aligned to block size
let result = cipher.encrypt(&[0u8; 15]); let result = cipher.encrypt(&[0u8; 15]);
assert!(result.is_err()); assert!(result.is_err());
// 17 bytes - not aligned
let result = cipher.encrypt(&[0u8; 17]); let result = cipher.encrypt(&[0u8; 17]);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test] #[test]
fn test_aes_cbc_avalanche_effect() { fn test_aes_cbc_avalanche_effect() {
// Changing one bit in plaintext should change entire ciphertext block
// and all subsequent blocks (due to chaining)
let key = [0xAB; 32]; let key = [0xAB; 32];
let iv = [0xCD; 16]; let iv = [0xCD; 16];
let mut plaintext1 = [0u8; 32]; let plaintext1 = [0u8; 32];
let mut plaintext2 = [0u8; 32]; let mut plaintext2 = [0u8; 32];
plaintext2[0] = 0x01; // Single bit difference in first block plaintext2[0] = 0x01;
let cipher = AesCbc::new(key, iv); let cipher = AesCbc::new(key, iv);
let ciphertext1 = cipher.encrypt(&plaintext1).unwrap(); let ciphertext1 = cipher.encrypt(&plaintext1).unwrap();
let ciphertext2 = cipher.encrypt(&plaintext2).unwrap(); let ciphertext2 = cipher.encrypt(&plaintext2).unwrap();
// First blocks should be different
assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]); assert_ne!(&ciphertext1[0..16], &ciphertext2[0..16]);
// Second blocks should ALSO be different (chaining effect)
assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]); assert_ne!(&ciphertext1[16..32], &ciphertext2[16..32]);
} }
#[test] #[test]
fn test_aes_cbc_iv_matters() { fn test_aes_cbc_iv_matters() {
// Same plaintext with different IVs should produce different ciphertext
let key = [0x55; 32]; let key = [0x55; 32];
let plaintext = [0x77u8; 16]; let plaintext = [0x77u8; 16];
@@ -511,7 +491,6 @@ mod tests {
#[test] #[test]
fn test_aes_cbc_deterministic() { fn test_aes_cbc_deterministic() {
// Same key, IV, plaintext should always produce same ciphertext
let key = [0x99; 32]; let key = [0x99; 32];
let iv = [0x88; 16]; let iv = [0x88; 16];
let plaintext = [0x77u8; 32]; let plaintext = [0x77u8; 32];
@@ -524,6 +503,23 @@ mod tests {
assert_eq!(ciphertext1, ciphertext2); assert_eq!(ciphertext1, ciphertext2);
} }
// ============= Zeroize Tests =============
#[test]
fn test_aes_cbc_zeroize_on_drop() {
let key = [0xAA; 32];
let iv = [0xBB; 16];
let cipher = AesCbc::new(key, iv);
// Verify key/iv are set
assert_eq!(cipher.key, [0xAA; 32]);
assert_eq!(cipher.iv, [0xBB; 16]);
drop(cipher);
// After drop, key/iv are zeroized (can't observe directly,
// but the Drop impl runs without panic)
}
// ============= Error Handling Tests ============= // ============= Error Handling Tests =============
#[test] #[test]

View File

@@ -1,3 +1,16 @@
//! Cryptographic hash functions
//!
//! ## Protocol-required algorithms
//!
//! This module exposes MD5 and SHA-1 alongside SHA-256. These weaker
//! hash functions are **required by the Telegram Middle Proxy protocol**
//! (`derive_middleproxy_keys`) and cannot be replaced without breaking
//! compatibility. They are NOT used for any security-sensitive purpose
//! outside of that specific key derivation scheme mandated by Telegram.
//!
//! Static analysis tools (CodeQL, cargo-audit) may flag them — the
//! usages are intentional and protocol-mandated.
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use sha2::Sha256; use sha2::Sha256;
use md5::Md5; use md5::Md5;
@@ -21,14 +34,16 @@ pub fn sha256_hmac(key: &[u8], data: &[u8]) -> [u8; 32] {
mac.finalize().into_bytes().into() mac.finalize().into_bytes().into()
} }
/// SHA-1 /// SHA-1 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn sha1(data: &[u8]) -> [u8; 20] { pub fn sha1(data: &[u8]) -> [u8; 20] {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(data); hasher.update(data);
hasher.finalize().into() hasher.finalize().into()
} }
/// MD5 /// MD5 — **protocol-required** by Telegram Middle Proxy key derivation.
/// Not used for general-purpose hashing.
pub fn md5(data: &[u8]) -> [u8; 16] { pub fn md5(data: &[u8]) -> [u8; 16] {
let mut hasher = Md5::new(); let mut hasher = Md5::new();
hasher.update(data); hasher.update(data);
@@ -40,7 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// Middle Proxy Keygen /// Middle Proxy key derivation
///
/// Uses MD5 + SHA-1 as mandated by the Telegram Middle Proxy protocol.
/// These algorithms are NOT replaceable here — changing them would break
/// interoperability with Telegram's middle proxy infrastructure.
pub fn derive_middleproxy_keys( pub fn derive_middleproxy_keys(
nonce_srv: &[u8; 16], nonce_srv: &[u8; 16],
nonce_clt: &[u8; 16], nonce_clt: &[u8; 16],

View File

@@ -6,4 +6,4 @@ pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32}; pub use hash::{sha256, sha256_hmac, sha1, md5, crc32};
pub use random::{SecureRandom, SECURE_RANDOM}; pub use random::SecureRandom;

View File

@@ -3,11 +3,8 @@
use rand::{Rng, RngCore, SeedableRng}; use rand::{Rng, RngCore, SeedableRng};
use rand::rngs::StdRng; use rand::rngs::StdRng;
use parking_lot::Mutex; use parking_lot::Mutex;
use zeroize::Zeroize;
use crate::crypto::AesCtr; use crate::crypto::AesCtr;
use once_cell::sync::Lazy;
/// Global secure random instance
pub static SECURE_RANDOM: Lazy<SecureRandom> = Lazy::new(SecureRandom::new);
/// Cryptographically secure PRNG with AES-CTR /// Cryptographically secure PRNG with AES-CTR
pub struct SecureRandom { pub struct SecureRandom {
@@ -20,18 +17,30 @@ struct SecureRandomInner {
buffer: Vec<u8>, buffer: Vec<u8>,
} }
impl Drop for SecureRandomInner {
fn drop(&mut self) {
self.buffer.zeroize();
}
}
impl SecureRandom { impl SecureRandom {
pub fn new() -> Self { pub fn new() -> Self {
let mut rng = StdRng::from_entropy(); let mut seed_source = rand::rng();
let mut rng = StdRng::from_rng(&mut seed_source);
let mut key = [0u8; 32]; let mut key = [0u8; 32];
rng.fill_bytes(&mut key); rng.fill_bytes(&mut key);
let iv: u128 = rng.gen(); let iv: u128 = rng.random();
let cipher = AesCtr::new(&key, iv);
// Zeroize local key copy — cipher already consumed it
key.zeroize();
Self { Self {
inner: Mutex::new(SecureRandomInner { inner: Mutex::new(SecureRandomInner {
rng, rng,
cipher: AesCtr::new(&key, iv), cipher,
buffer: Vec::with_capacity(1024), buffer: Vec::with_capacity(1024),
}), }),
} }
@@ -78,7 +87,6 @@ impl SecureRandom {
result |= (b as u64) << (i * 8); result |= (b as u64) << (i * 8);
} }
// Mask extra bits
if k < 64 { if k < 64 {
result &= (1u64 << k) - 1; result &= (1u64 << k) - 1;
} }
@@ -107,13 +115,13 @@ impl SecureRandom {
/// Generate random u32 /// Generate random u32
pub fn u32(&self) -> u32 { pub fn u32(&self) -> u32 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
/// Generate random u64 /// Generate random u64
pub fn u64(&self) -> u64 { pub fn u64(&self) -> u64 {
let mut inner = self.inner.lock(); let mut inner = self.inner.lock();
inner.rng.gen() inner.rng.random()
} }
} }
@@ -162,12 +170,10 @@ mod tests {
fn test_bits() { fn test_bits() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
// Single bit should be 0 or 1
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(1) <= 1); assert!(rng.bits(1) <= 1);
} }
// 8 bits should be 0-255
for _ in 0..100 { for _ in 0..100 {
assert!(rng.bits(8) <= 255); assert!(rng.bits(8) <= 255);
} }
@@ -185,10 +191,8 @@ mod tests {
} }
} }
// Should have seen all items
assert_eq!(seen.len(), 5); assert_eq!(seen.len(), 5);
// Empty slice should return None
let empty: Vec<i32> = vec![]; let empty: Vec<i32> = vec![];
assert!(rng.choose(&empty).is_none()); assert!(rng.choose(&empty).is_none());
} }
@@ -201,12 +205,10 @@ mod tests {
let mut shuffled = original.clone(); let mut shuffled = original.clone();
rng.shuffle(&mut shuffled); rng.shuffle(&mut shuffled);
// Should contain same elements
let mut sorted = shuffled.clone(); let mut sorted = shuffled.clone();
sorted.sort(); sorted.sort();
assert_eq!(sorted, original); assert_eq!(sorted, original);
// Should be different order (with very high probability)
assert_ne!(shuffled, original); assert_ne!(shuffled, original);
} }
} }

View File

@@ -118,16 +118,13 @@ pub trait Recoverable {
impl Recoverable for StreamError { impl Recoverable for StreamError {
fn is_recoverable(&self) -> bool { fn is_recoverable(&self) -> bool {
match self { match self {
// Partial operations can be retried
Self::PartialRead { .. } | Self::PartialWrite { .. } => true, Self::PartialRead { .. } | Self::PartialWrite { .. } => true,
// I/O errors depend on kind
Self::Io(e) => matches!( Self::Io(e) => matches!(
e.kind(), e.kind(),
std::io::ErrorKind::WouldBlock std::io::ErrorKind::WouldBlock
| std::io::ErrorKind::Interrupted | std::io::ErrorKind::Interrupted
| std::io::ErrorKind::TimedOut | std::io::ErrorKind::TimedOut
), ),
// These are not recoverable
Self::Poisoned { .. } Self::Poisoned { .. }
| Self::BufferOverflow { .. } | Self::BufferOverflow { .. }
| Self::InvalidFrame { .. } | Self::InvalidFrame { .. }
@@ -137,13 +134,9 @@ impl Recoverable for StreamError {
fn can_continue(&self) -> bool { fn can_continue(&self) -> bool {
match self { match self {
// Poisoned stream cannot be used
Self::Poisoned { .. } => false, Self::Poisoned { .. } => false,
// EOF means stream is done
Self::UnexpectedEof => false, Self::UnexpectedEof => false,
// Buffer overflow is fatal
Self::BufferOverflow { .. } => false, Self::BufferOverflow { .. } => false,
// Others might allow continuation
_ => true, _ => true,
} }
} }
@@ -297,16 +290,16 @@ pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Result with optional bad client handling /// Result with optional bad client handling
#[derive(Debug)] #[derive(Debug)]
pub enum HandshakeResult<T> { pub enum HandshakeResult<T, R, W> {
/// Handshake succeeded /// Handshake succeeded
Success(T), Success(T),
/// Client failed validation, needs masking /// Client failed validation, needs masking. Returns ownership of streams.
BadClient, BadClient { reader: R, writer: W },
/// Error occurred /// Error occurred
Error(ProxyError), Error(ProxyError),
} }
impl<T> HandshakeResult<T> { impl<T, R, W> HandshakeResult<T, R, W> {
/// Check if successful /// Check if successful
pub fn is_success(&self) -> bool { pub fn is_success(&self) -> bool {
matches!(self, HandshakeResult::Success(_)) matches!(self, HandshakeResult::Success(_))
@@ -314,49 +307,32 @@ impl<T> HandshakeResult<T> {
/// Check if bad client /// Check if bad client
pub fn is_bad_client(&self) -> bool { pub fn is_bad_client(&self) -> bool {
matches!(self, HandshakeResult::BadClient) matches!(self, HandshakeResult::BadClient { .. })
}
/// Convert to Result, treating BadClient as error
pub fn into_result(self) -> Result<T> {
match self {
HandshakeResult::Success(v) => Ok(v),
HandshakeResult::BadClient => Err(ProxyError::InvalidHandshake("Bad client".into())),
HandshakeResult::Error(e) => Err(e),
}
} }
/// Map the success value /// Map the success value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U> { pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> HandshakeResult<U, R, W> {
match self { match self {
HandshakeResult::Success(v) => HandshakeResult::Success(f(v)), HandshakeResult::Success(v) => HandshakeResult::Success(f(v)),
HandshakeResult::BadClient => HandshakeResult::BadClient, HandshakeResult::BadClient { reader, writer } => HandshakeResult::BadClient { reader, writer },
HandshakeResult::Error(e) => HandshakeResult::Error(e), HandshakeResult::Error(e) => HandshakeResult::Error(e),
} }
} }
/// Convert success to Option
pub fn ok(self) -> Option<T> {
match self {
HandshakeResult::Success(v) => Some(v),
_ => None,
}
}
} }
impl<T> From<ProxyError> for HandshakeResult<T> { impl<T, R, W> From<ProxyError> for HandshakeResult<T, R, W> {
fn from(err: ProxyError) -> Self { fn from(err: ProxyError) -> Self {
HandshakeResult::Error(err) HandshakeResult::Error(err)
} }
} }
impl<T> From<std::io::Error> for HandshakeResult<T> { impl<T, R, W> From<std::io::Error> for HandshakeResult<T, R, W> {
fn from(err: std::io::Error) -> Self { fn from(err: std::io::Error) -> Self {
HandshakeResult::Error(ProxyError::Io(err)) HandshakeResult::Error(ProxyError::Io(err))
} }
} }
impl<T> From<StreamError> for HandshakeResult<T> { impl<T, R, W> From<StreamError> for HandshakeResult<T, R, W> {
fn from(err: StreamError) -> Self { fn from(err: StreamError) -> Self {
HandshakeResult::Error(ProxyError::Stream(err)) HandshakeResult::Error(ProxyError::Stream(err))
} }
@@ -400,18 +376,18 @@ mod tests {
#[test] #[test]
fn test_handshake_result() { fn test_handshake_result() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
assert!(success.is_success()); assert!(success.is_success());
assert!(!success.is_bad_client()); assert!(!success.is_bad_client());
let bad: HandshakeResult<i32> = HandshakeResult::BadClient; let bad: HandshakeResult<i32, (), ()> = HandshakeResult::BadClient { reader: (), writer: () };
assert!(!bad.is_success()); assert!(!bad.is_success());
assert!(bad.is_bad_client()); assert!(bad.is_bad_client());
} }
#[test] #[test]
fn test_handshake_result_map() { fn test_handshake_result_map() {
let success: HandshakeResult<i32> = HandshakeResult::Success(42); let success: HandshakeResult<i32, (), ()> = HandshakeResult::Success(42);
let mapped = success.map(|x| x * 2); let mapped = success.map(|x| x * 2);
match mapped { match mapped {

View File

@@ -5,9 +5,11 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::signal; use tokio::signal;
use tracing::{info, error, warn}; use tokio::sync::Semaphore;
use tracing_subscriber::{fmt, EnvFilter}; use tracing::{info, error, warn, debug};
use tracing_subscriber::{fmt, EnvFilter, reload, prelude::*};
mod cli;
mod config; mod config;
mod crypto; mod crypto;
mod error; mod error;
@@ -18,63 +20,178 @@ mod stream;
mod transport; mod transport;
mod util; mod util;
use crate::config::ProxyConfig; use crate::config::{ProxyConfig, LogLevel};
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::crypto::SecureRandom;
use crate::transport::{create_listener, ListenOptions, UpstreamManager}; use crate::transport::{create_listener, ListenOptions, UpstreamManager};
use crate::util::ip::detect_ip; use crate::util::ip::detect_ip;
use crate::stream::BufferPool;
fn parse_cli() -> (String, bool, Option<String>) {
let mut config_path = "config.toml".to_string();
let mut silent = false;
let mut log_level: Option<String> = None;
let args: Vec<String> = std::env::args().skip(1).collect();
// Check for --init first (handled before tokio)
if let Some(init_opts) = cli::parse_init_args(&args) {
if let Err(e) = cli::run_init(init_opts) {
eprintln!("[telemt] Init failed: {}", e);
std::process::exit(1);
}
std::process::exit(0);
}
let mut i = 0;
while i < args.len() {
match args[i].as_str() {
"--silent" | "-s" => { silent = true; }
"--log-level" => {
i += 1;
if i < args.len() { log_level = Some(args[i].clone()); }
}
s if s.starts_with("--log-level=") => {
log_level = Some(s.trim_start_matches("--log-level=").to_string());
}
"--help" | "-h" => {
eprintln!("Usage: telemt [config.toml] [OPTIONS]");
eprintln!();
eprintln!("Options:");
eprintln!(" --silent, -s Suppress info logs");
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
eprintln!(" --help, -h Show this help");
eprintln!();
eprintln!("Setup (fire-and-forget):");
eprintln!(" --init Generate config, install systemd service, start");
eprintln!(" --port <PORT> Listen port (default: 443)");
eprintln!(" --domain <DOMAIN> TLS domain for masking (default: www.google.com)");
eprintln!(" --secret <HEX> 32-char hex secret (auto-generated if omitted)");
eprintln!(" --user <NAME> Username (default: user)");
eprintln!(" --config-dir <DIR> Config directory (default: /etc/telemt)");
eprintln!(" --no-start Don't start the service after install");
std::process::exit(0);
}
s if !s.starts_with('-') => { config_path = s.to_string(); }
other => { eprintln!("Unknown option: {}", other); }
}
i += 1;
}
(config_path, silent, log_level)
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Initialize logging let (config_path, cli_silent, cli_log_level) = parse_cli();
fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive("info".parse().unwrap()))
.init();
// Load config
let config_path = std::env::args().nth(1).unwrap_or_else(|| "config.toml".to_string());
let config = match ProxyConfig::load(&config_path) { let config = match ProxyConfig::load(&config_path) {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
// If config doesn't exist, try to create default
if std::path::Path::new(&config_path).exists() { if std::path::Path::new(&config_path).exists() {
error!("Failed to load config: {}", e); eprintln!("[telemt] Error: {}", e);
std::process::exit(1); std::process::exit(1);
} else { } else {
let default = ProxyConfig::default(); let default = ProxyConfig::default();
let toml = toml::to_string_pretty(&default).unwrap(); std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
std::fs::write(&config_path, toml).unwrap(); eprintln!("[telemt] Created default config at {}", config_path);
info!("Created default config at {}", config_path);
default default
} }
} }
}; };
config.validate()?; if let Err(e) = config.validate() {
eprintln!("[telemt] Invalid config: {}", e);
std::process::exit(1);
}
let has_rust_log = std::env::var("RUST_LOG").is_ok();
let effective_log_level = if cli_silent {
LogLevel::Silent
} else if let Some(ref s) = cli_log_level {
LogLevel::from_str_loose(s)
} else {
config.general.log_level.clone()
};
// Start with INFO so startup messages are always visible,
// then switch to user-configured level after startup
let (filter_layer, filter_handle) = reload::Layer::new(EnvFilter::new("info"));
tracing_subscriber::registry()
.with(filter_layer)
.with(fmt::Layer::default())
.init();
info!("Telemt MTProxy v{}", env!("CARGO_PKG_VERSION"));
info!("Log level: {}", effective_log_level);
info!("Modes: classic={} secure={} tls={}",
config.general.modes.classic,
config.general.modes.secure,
config.general.modes.tls);
info!("TLS domain: {}", config.censorship.tls_domain);
info!("Mask: {} -> {}:{}",
config.censorship.mask,
config.censorship.mask_host.as_deref().unwrap_or(&config.censorship.tls_domain),
config.censorship.mask_port);
if config.censorship.tls_domain == "www.google.com" {
warn!("Using default tls_domain. Consider setting a custom domain.");
}
let prefer_ipv6 = config.general.prefer_ipv6;
let config = Arc::new(config); let config = Arc::new(config);
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
let rng = Arc::new(SecureRandom::new());
// CHANGED: Initialize global ReplayChecker here instead of per-connection let replay_checker = Arc::new(ReplayChecker::new(
let replay_checker = Arc::new(ReplayChecker::new(config.replay_check_len)); config.access.replay_check_len,
Duration::from_secs(config.access.replay_window_secs),
));
// Initialize Upstream Manager
let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone())); let upstream_manager = Arc::new(UpstreamManager::new(config.upstreams.clone()));
let buffer_pool = Arc::new(BufferPool::with_config(16 * 1024, 4096));
// Start Health Checks // Connection concurrency limit — prevents OOM under SYN flood / connection storm.
// 10000 is generous; each connection uses ~64KB (2x 16KB relay buffers + overhead).
// 10000 connections ≈ 640MB peak memory.
let max_connections = Arc::new(Semaphore::new(10_000));
// Startup DC ping
info!("=== Telegram DC Connectivity ===");
let ping_results = upstream_manager.ping_all_dcs(prefer_ipv6).await;
for upstream_result in &ping_results {
info!(" via {}", upstream_result.upstream_name);
for dc in &upstream_result.results {
match (&dc.rtt_ms, &dc.error) {
(Some(rtt), _) => {
info!(" DC{} ({:>21}): {:.0}ms", dc.dc_idx, dc.dc_addr, rtt);
}
(None, Some(err)) => {
info!(" DC{} ({:>21}): FAIL ({})", dc.dc_idx, dc.dc_addr, err);
}
_ => {
info!(" DC{} ({:>21}): FAIL", dc.dc_idx, dc.dc_addr);
}
}
}
}
info!("================================");
// Background tasks
let um_clone = upstream_manager.clone(); let um_clone = upstream_manager.clone();
tokio::spawn(async move { tokio::spawn(async move { um_clone.run_health_checks(prefer_ipv6).await; });
um_clone.run_health_checks().await;
}); let rc_clone = replay_checker.clone();
tokio::spawn(async move { rc_clone.run_periodic_cleanup().await; });
// Detect public IP if needed (once at startup)
let detected_ip = detect_ip().await; let detected_ip = detect_ip().await;
debug!("Detected IPs: v4={:?} v6={:?}", detected_ip.ipv4, detected_ip.ipv6);
// Start Listeners
let mut listeners = Vec::new(); let mut listeners = Vec::new();
for listener_conf in &config.listeners { for listener_conf in &config.server.listeners {
let addr = SocketAddr::new(listener_conf.ip, config.port); let addr = SocketAddr::new(listener_conf.ip, config.server.port);
let options = ListenOptions { let options = ListenOptions {
ipv6_only: listener_conf.ip.is_ipv6(), ipv6_only: listener_conf.ip.is_ipv6(),
..Default::default() ..Default::default()
@@ -85,14 +202,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::from_std(socket.into())?; let listener = TcpListener::from_std(socket.into())?;
info!("Listening on {}", addr); info!("Listening on {}", addr);
// Determine public IP for tg:// links
// 1. Use explicit announce_ip if set
// 2. If listening on 0.0.0.0 or ::, use detected public IP
// 3. Otherwise use the bind IP
let public_ip = if let Some(ip) = listener_conf.announce_ip { let public_ip = if let Some(ip) = listener_conf.announce_ip {
ip ip
} else if listener_conf.ip.is_unspecified() { } else if listener_conf.ip.is_unspecified() {
// Try to use detected IP of the same family
if listener_conf.ip.is_ipv4() { if listener_conf.ip.is_ipv4() {
detected_ip.ipv4.unwrap_or(listener_conf.ip) detected_ip.ipv4.unwrap_or(listener_conf.ip)
} else { } else {
@@ -102,36 +214,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
listener_conf.ip listener_conf.ip
}; };
// Show links for configured users
if !config.show_link.is_empty() { if !config.show_link.is_empty() {
info!("--- Proxy Links for {} ---", public_ip); info!("--- Proxy Links ({}) ---", public_ip);
for user_name in &config.show_link { for user_name in &config.show_link {
if let Some(secret) = config.users.get(user_name) { if let Some(secret) = config.access.users.get(user_name) {
info!("User: {}", user_name); info!("User: {}", user_name);
if config.general.modes.classic {
// Classic
if config.modes.classic {
info!(" Classic: tg://proxy?server={}&port={}&secret={}", info!(" Classic: tg://proxy?server={}&port={}&secret={}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
if config.general.modes.secure {
// DD (Secure)
if config.modes.secure {
info!(" DD: tg://proxy?server={}&port={}&secret=dd{}", info!(" DD: tg://proxy?server={}&port={}&secret=dd{}",
public_ip, config.port, secret); public_ip, config.server.port, secret);
} }
if config.general.modes.tls {
// EE-TLS (FakeTLS) let domain_hex = hex::encode(&config.censorship.tls_domain);
if config.modes.tls {
let domain_hex = hex::encode(&config.tls_domain);
info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}", info!(" EE-TLS: tg://proxy?server={}&port={}&secret=ee{}{}",
public_ip, config.port, secret, domain_hex); public_ip, config.server.port, secret, domain_hex);
} }
} else { } else {
warn!("User '{}' specified in show_link not found in users list", user_name); warn!("User '{}' in show_link not found", user_name);
} }
} }
info!("-----------------------------------"); info!("------------------------");
} }
listeners.push(listener); listeners.push(listener);
@@ -143,16 +248,25 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
if listeners.is_empty() { if listeners.is_empty() {
error!("No listeners could be started. Exiting."); error!("No listeners. Exiting.");
std::process::exit(1); std::process::exit(1);
} }
// Accept loop // Switch to user-configured log level after startup
let runtime_filter = if has_rust_log {
EnvFilter::from_default_env()
} else {
EnvFilter::new(effective_log_level.to_filter_str())
};
filter_handle.reload(runtime_filter).expect("Failed to switch log filter");
for listener in listeners { for listener in listeners {
let config = config.clone(); let config = config.clone();
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
@@ -162,18 +276,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let stats = stats.clone(); let stats = stats.clone();
let upstream_manager = upstream_manager.clone(); let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone(); let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = ClientHandler::new( if let Err(e) = ClientHandler::new(
stream, stream, peer_addr, config, stats,
peer_addr, upstream_manager, replay_checker, buffer_pool, rng
config,
stats,
upstream_manager,
replay_checker // Pass global checker
).run().await { ).run().await {
// Log only relevant errors debug!(peer = %peer_addr, error = %e, "Connection error");
// debug!("Connection error: {}", e);
} }
}); });
} }
@@ -186,7 +297,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}); });
} }
// Wait for signal
match signal::ctrl_c().await { match signal::ctrl_c().await {
Ok(()) => info!("Shutting down..."), Ok(()) => info!("Shutting down..."),
Err(e) => error!("Signal error: {}", e), Err(e) => error!("Signal error: {}", e),

View File

@@ -1,13 +1,13 @@
//! Protocol constants and datacenter addresses //! Protocol constants and datacenter addresses
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use once_cell::sync::Lazy; use std::sync::LazyLock;
// ============= Telegram Datacenters ============= // ============= Telegram Datacenters =============
pub const TG_DATACENTER_PORT: u16 = 443; pub const TG_DATACENTER_PORT: u16 = 443;
pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V4: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)),
IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)), IpAddr::V4(Ipv4Addr::new(149, 154, 167, 51)),
@@ -17,7 +17,7 @@ pub static TG_DATACENTERS_V4: Lazy<Vec<IpAddr>> = Lazy::new(|| {
] ]
}); });
pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| { pub static TG_DATACENTERS_V6: LazyLock<Vec<IpAddr>> = LazyLock::new(|| {
vec![ vec![
IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()), IpAddr::V6("2001:b28:f23d:f001::a".parse().unwrap()),
IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()), IpAddr::V6("2001:67c:04e8:f002::a".parse().unwrap()),
@@ -29,8 +29,8 @@ pub static TG_DATACENTERS_V6: Lazy<Vec<IpAddr>> = Lazy::new(|| {
// ============= Middle Proxies (for advertising) ============= // ============= Middle Proxies (for advertising) =============
pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V4: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]); m.insert(-1, vec![(IpAddr::V4(Ipv4Addr::new(149, 154, 175, 50)), 8888)]);
@@ -45,8 +45,8 @@ pub static TG_MIDDLE_PROXIES_V4: Lazy<std::collections::HashMap<i32, Vec<(IpAddr
m m
}); });
pub static TG_MIDDLE_PROXIES_V6: Lazy<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> = pub static TG_MIDDLE_PROXIES_V6: LazyLock<std::collections::HashMap<i32, Vec<(IpAddr, u16)>>> =
Lazy::new(|| { LazyLock::new(|| {
let mut m = std::collections::HashMap::new(); let mut m = std::collections::HashMap::new();
m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]); m.insert(-1, vec![(IpAddr::V6("2001:b28:f23d:f001::d".parse().unwrap()), 8888)]);
@@ -167,8 +167,6 @@ pub const DEFAULT_ACK_TIMEOUT_SECS: u64 = 300;
// ============= Buffer Sizes ============= // ============= Buffer Sizes =============
/// Default buffer size /// Default buffer size
/// CHANGED: Reduced from 64KB to 16KB to match TLS record size and align with
/// the new buffering strategy for better iOS upload performance.
pub const DEFAULT_BUFFER_SIZE: usize = 16384; pub const DEFAULT_BUFFER_SIZE: usize = 16384;
/// Small buffer size for bad client handling /// Small buffer size for bad client handling

View File

@@ -1,10 +1,13 @@
//! MTProto Obfuscation //! MTProto Obfuscation
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr};
use crate::error::Result; use crate::error::Result;
use super::constants::*; use super::constants::*;
/// Obfuscation parameters from handshake /// Obfuscation parameters from handshake
///
/// Key material is zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ObfuscationParams { pub struct ObfuscationParams {
/// Key for decrypting client -> proxy traffic /// Key for decrypting client -> proxy traffic
@@ -21,25 +24,31 @@ pub struct ObfuscationParams {
pub dc_idx: i16, pub dc_idx: i16,
} }
impl Drop for ObfuscationParams {
fn drop(&mut self) {
self.decrypt_key.zeroize();
self.decrypt_iv.zeroize();
self.encrypt_key.zeroize();
self.encrypt_iv.zeroize();
}
}
impl ObfuscationParams { impl ObfuscationParams {
/// Parse obfuscation parameters from handshake bytes /// Parse obfuscation parameters from handshake bytes
/// Returns None if handshake doesn't match any user secret /// Returns None if handshake doesn't match any user secret
pub fn from_handshake( pub fn from_handshake(
handshake: &[u8; HANDSHAKE_LEN], handshake: &[u8; HANDSHAKE_LEN],
secrets: &[(String, Vec<u8>)], // (username, secret_bytes) secrets: &[(String, Vec<u8>)],
) -> Option<(Self, String)> { ) -> Option<(Self, String)> {
// Extract prekey and IV for decryption
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
for (username, secret) in secrets { for (username, secret) in secrets {
// Derive decryption key
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey); dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(secret); dec_key_input.extend_from_slice(secret);
@@ -47,26 +56,22 @@ impl ObfuscationParams {
let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let decrypt_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Create decryptor and decrypt handshake
let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv); let mut decryptor = AesCtr::new(&decrypt_key, decrypt_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => continue, // Try next secret None => continue,
}; };
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); let mut enc_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
enc_key_input.extend_from_slice(enc_prekey); enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(secret); enc_key_input.extend_from_slice(secret);
@@ -123,18 +128,15 @@ pub fn generate_nonce<R: FnMut(usize) -> Vec<u8>>(mut random_bytes: R) -> [u8; H
/// Check if nonce is valid (not matching reserved patterns) /// Check if nonce is valid (not matching reserved patterns)
pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool { pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
// Check first byte
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
return false; return false;
} }
// Check first 4 bytes
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) {
return false; return false;
} }
// Check bytes 4-7
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) {
return false; return false;
@@ -147,12 +149,10 @@ pub fn is_valid_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> bool {
pub fn prepare_tg_nonce( pub fn prepare_tg_nonce(
nonce: &mut [u8; HANDSHAKE_LEN], nonce: &mut [u8; HANDSHAKE_LEN],
proto_tag: ProtoTag, proto_tag: ProtoTag,
enc_key_iv: Option<&[u8]>, // For fast mode enc_key_iv: Option<&[u8]>,
) { ) {
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// For fast mode, copy the reversed enc_key_iv
if let Some(key_iv) = enc_key_iv { if let Some(key_iv) = enc_key_iv {
let reversed: Vec<u8> = key_iv.iter().rev().copied().collect(); let reversed: Vec<u8> = key_iv.iter().rev().copied().collect();
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN].copy_from_slice(&reversed);
@@ -161,14 +161,12 @@ pub fn prepare_tg_nonce(
/// Encrypt the outgoing nonce for Telegram /// Encrypt the outgoing nonce for Telegram
pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// Derive encryption key from the nonce itself
let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let enc_key = sha256(key_iv); let enc_key = sha256(key_iv);
let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap()); let enc_iv = u128::from_be_bytes(key_iv[..IV_LEN].try_into().unwrap());
let mut encryptor = AesCtr::new(&enc_key, enc_iv); let mut encryptor = AesCtr::new(&enc_key, enc_iv);
// Only encrypt from PROTO_TAG_POS onwards
let mut result = nonce.to_vec(); let mut result = nonce.to_vec();
let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]); let encrypted_part = encryptor.encrypt(&nonce[PROTO_TAG_POS..]);
result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part); result[PROTO_TAG_POS..].copy_from_slice(&encrypted_part);
@@ -182,22 +180,18 @@ mod tests {
#[test] #[test]
fn test_is_valid_nonce() { fn test_is_valid_nonce() {
// Valid nonce
let mut valid = [0x42u8; HANDSHAKE_LEN]; let mut valid = [0x42u8; HANDSHAKE_LEN];
valid[4..8].copy_from_slice(&[1, 2, 3, 4]); valid[4..8].copy_from_slice(&[1, 2, 3, 4]);
assert!(is_valid_nonce(&valid)); assert!(is_valid_nonce(&valid));
// Invalid: starts with 0xef
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[0] = 0xef; invalid[0] = 0xef;
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: starts with HEAD
let mut invalid = [0x00u8; HANDSHAKE_LEN]; let mut invalid = [0x00u8; HANDSHAKE_LEN];
invalid[..4].copy_from_slice(b"HEAD"); invalid[..4].copy_from_slice(b"HEAD");
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));
// Invalid: bytes 4-7 are zeros
let mut invalid = [0x42u8; HANDSHAKE_LEN]; let mut invalid = [0x42u8; HANDSHAKE_LEN];
invalid[4..8].copy_from_slice(&[0, 0, 0, 0]); invalid[4..8].copy_from_slice(&[0, 0, 0, 0]);
assert!(!is_valid_nonce(&invalid)); assert!(!is_valid_nonce(&invalid));

View File

@@ -4,7 +4,7 @@
//! for domain fronting. The handshake looks like valid TLS 1.3 but //! for domain fronting. The handshake looks like valid TLS 1.3 but
//! actually carries MTProto authentication data. //! actually carries MTProto authentication data.
use crate::crypto::{sha256_hmac, random::SECURE_RANDOM}; use crate::crypto::{sha256_hmac, SecureRandom};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use super::constants::*; use super::constants::*;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
@@ -315,8 +315,8 @@ pub fn validate_tls_handshake(
/// ///
/// This generates random bytes that look like a valid X25519 public key. /// This generates random bytes that look like a valid X25519 public key.
/// Since we're not doing real TLS, the actual cryptographic properties don't matter. /// Since we're not doing real TLS, the actual cryptographic properties don't matter.
pub fn gen_fake_x25519_key() -> [u8; 32] { pub fn gen_fake_x25519_key(rng: &SecureRandom) -> [u8; 32] {
let bytes = SECURE_RANDOM.bytes(32); let bytes = rng.bytes(32);
bytes.try_into().unwrap() bytes.try_into().unwrap()
} }
@@ -333,8 +333,9 @@ pub fn build_server_hello(
client_digest: &[u8; TLS_DIGEST_LEN], client_digest: &[u8; TLS_DIGEST_LEN],
session_id: &[u8], session_id: &[u8],
fake_cert_len: usize, fake_cert_len: usize,
rng: &SecureRandom,
) -> Vec<u8> { ) -> Vec<u8> {
let x25519_key = gen_fake_x25519_key(); let x25519_key = gen_fake_x25519_key(rng);
// Build ServerHello // Build ServerHello
let server_hello = ServerHelloBuilder::new(session_id.to_vec()) let server_hello = ServerHelloBuilder::new(session_id.to_vec())
@@ -351,7 +352,7 @@ pub fn build_server_hello(
]; ];
// Build fake certificate (Application Data record) // Build fake certificate (Application Data record)
let fake_cert = SECURE_RANDOM.bytes(fake_cert_len); let fake_cert = rng.bytes(fake_cert_len);
let mut app_data_record = Vec::with_capacity(5 + fake_cert_len); let mut app_data_record = Vec::with_capacity(5 + fake_cert_len);
app_data_record.push(TLS_RECORD_APPLICATION); app_data_record.push(TLS_RECORD_APPLICATION);
app_data_record.extend_from_slice(&TLS_VERSION); app_data_record.extend_from_slice(&TLS_VERSION);
@@ -489,8 +490,9 @@ mod tests {
#[test] #[test]
fn test_gen_fake_x25519_key() { fn test_gen_fake_x25519_key() {
let key1 = gen_fake_x25519_key(); let rng = SecureRandom::new();
let key2 = gen_fake_x25519_key(); let key1 = gen_fake_x25519_key(&rng);
let key2 = gen_fake_x25519_key(&rng);
assert_eq!(key1.len(), 32); assert_eq!(key1.len(), 32);
assert_eq!(key2.len(), 32); assert_eq!(key2.len(), 32);
@@ -545,7 +547,8 @@ mod tests {
let client_digest = [0x42u8; 32]; let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32]; let session_id = vec![0xAA; 32];
let response = build_server_hello(secret, &client_digest, &session_id, 2048); let rng = SecureRandom::new();
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng);
// Should have at least 3 records // Should have at least 3 records
assert!(response.len() > 100); assert!(response.len() > 100);
@@ -577,8 +580,9 @@ mod tests {
let client_digest = [0x42u8; 32]; let client_digest = [0x42u8; 32];
let session_id = vec![0xAA; 32]; let session_id = vec![0xAA; 32];
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024); let rng = SecureRandom::new();
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024); let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng);
// Digest position should have non-zero data // Digest position should have non-zero data
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]; let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];

View File

@@ -14,20 +14,18 @@ use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stats::{Stats, ReplayChecker}; use crate::stats::{Stats, ReplayChecker};
use crate::transport::{configure_client_socket, UpstreamManager}; use crate::transport::{configure_client_socket, UpstreamManager};
use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter}; use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter, BufferPool};
use crate::crypto::AesCtr; use crate::crypto::{AesCtr, SecureRandom};
use super::handshake::{ use crate::proxy::handshake::{
handle_tls_handshake, handle_mtproto_handshake, handle_tls_handshake, handle_mtproto_handshake,
HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce, HandshakeSuccess, generate_tg_nonce, encrypt_tg_nonce,
}; };
use super::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use super::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
/// Client connection handler (builder struct)
pub struct ClientHandler; pub struct ClientHandler;
/// Running client handler with stream and context
pub struct RunningClientHandler { pub struct RunningClientHandler {
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
@@ -35,59 +33,47 @@ pub struct RunningClientHandler {
stats: Arc<Stats>, stats: Arc<Stats>,
replay_checker: Arc<ReplayChecker>, replay_checker: Arc<ReplayChecker>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
} }
impl ClientHandler { impl ClientHandler {
/// Create new client handler instance
pub fn new( pub fn new(
stream: TcpStream, stream: TcpStream,
peer: SocketAddr, peer: SocketAddr,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
stats: Arc<Stats>, stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>, // CHANGED: Accept global checker replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> RunningClientHandler { ) -> RunningClientHandler {
// CHANGED: Removed local creation of ReplayChecker.
// It is now passed from main.rs to ensure global replay protection.
RunningClientHandler { RunningClientHandler {
stream, stream, peer, config, stats, replay_checker,
peer, upstream_manager, buffer_pool, rng,
config,
stats,
replay_checker,
upstream_manager,
} }
} }
} }
impl RunningClientHandler { impl RunningClientHandler {
/// Run the client handler
pub async fn run(mut self) -> Result<()> { pub async fn run(mut self) -> Result<()> {
self.stats.increment_connects_all(); self.stats.increment_connects_all();
let peer = self.peer; let peer = self.peer;
debug!(peer = %peer, "New connection"); debug!(peer = %peer, "New connection");
// Configure socket
if let Err(e) = configure_client_socket( if let Err(e) = configure_client_socket(
&self.stream, &self.stream,
self.config.client_keepalive, self.config.timeouts.client_keepalive,
self.config.client_ack_timeout, self.config.timeouts.client_ack,
) { ) {
debug!(peer = %peer, error = %e, "Failed to configure client socket"); debug!(peer = %peer, error = %e, "Failed to configure client socket");
} }
// Perform handshake with timeout let handshake_timeout = Duration::from_secs(self.config.timeouts.client_handshake);
let handshake_timeout = Duration::from_secs(self.config.client_handshake_timeout);
// Clone stats for error handling block
let stats = self.stats.clone(); let stats = self.stats.clone();
let result = timeout( let result = timeout(handshake_timeout, self.do_handshake()).await;
handshake_timeout,
self.do_handshake()
).await;
match result { match result {
Ok(Ok(())) => { Ok(Ok(())) => {
@@ -106,16 +92,14 @@ impl RunningClientHandler {
} }
} }
/// Perform handshake and relay
async fn do_handshake(mut self) -> Result<()> { async fn do_handshake(mut self) -> Result<()> {
// Read first bytes to determine handshake type
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
self.stream.read_exact(&mut first_bytes).await?; self.stream.read_exact(&mut first_bytes).await?;
let is_tls = tls::is_tls_handshake(&first_bytes[..3]); let is_tls = tls::is_tls_handshake(&first_bytes[..3]);
let peer = self.peer; let peer = self.peer;
debug!(peer = %peer, is_tls = is_tls, first_bytes = %hex::encode(&first_bytes), "Handshake type detected"); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls { if is_tls {
self.handle_tls_client(first_bytes).await self.handle_tls_client(first_bytes).await
@@ -124,14 +108,9 @@ impl RunningClientHandler {
} }
} }
/// Handle TLS-wrapped client async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_tls_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer; let peer = self.peer;
// Read TLS handshake length
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
@@ -139,138 +118,102 @@ impl RunningClientHandler {
if tls_len < 512 { if tls_len < 512 {
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read full TLS handshake
let mut handshake = vec![0u8; 5 + tls_len]; let mut handshake = vec![0u8; 5 + tls_len];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields before consuming self.stream
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream for reading/writing
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
// Handle TLS handshake
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, &self.rng,
write_half,
peer,
&config,
&replay_checker,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Read MTProto handshake through TLS
debug!(peer = %peer, "Reading MTProto handshake through TLS"); debug!(peer = %peer, "Reading MTProto handshake through TLS");
let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?; let mtproto_data = tls_reader.read_exact(HANDSHAKE_LEN).await?;
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into() let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?; .map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, &mtproto_handshake, tls_reader, tls_writer, peer,
tls_reader, &config, &replay_checker, true,
tls_writer,
peer,
&config,
&replay_checker,
true,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader: _, writer: _ } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
// Handle authenticated client
// We can't use self.handle_authenticated_inner because self is partially moved
// So we call it as an associated function or method on a new struct,
// or just inline the logic / use a static method.
// Since handle_authenticated_inner needs self.upstream_manager and self.stats,
// we should pass them explicitly.
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader, crypto_writer, success,
crypto_writer, self.upstream_manager, self.stats, self.config,
success, buffer_pool, self.rng,
self.upstream_manager,
self.stats,
self.config
).await ).await
} }
/// Handle direct (non-TLS) client async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<()> {
async fn handle_direct_client(
mut self,
first_bytes: [u8; 5],
) -> Result<()> {
let peer = self.peer; let peer = self.peer;
// Check if non-TLS modes are enabled if !self.config.general.modes.classic && !self.config.general.modes.secure {
if !self.config.modes.classic && !self.config.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled"); debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
handle_bad_client(self.stream, &first_bytes, &self.config).await; let (reader, writer) = self.stream.into_split();
handle_bad_client(reader, writer, &first_bytes, &self.config).await;
return Ok(()); return Ok(());
} }
// Read rest of handshake
let mut handshake = [0u8; HANDSHAKE_LEN]; let mut handshake = [0u8; HANDSHAKE_LEN];
handshake[..5].copy_from_slice(&first_bytes); handshake[..5].copy_from_slice(&first_bytes);
self.stream.read_exact(&mut handshake[5..]).await?; self.stream.read_exact(&mut handshake[5..]).await?;
// Extract fields
let config = self.config.clone(); let config = self.config.clone();
let replay_checker = self.replay_checker.clone(); let replay_checker = self.replay_checker.clone();
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone();
// Split stream
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
// Handle MTProto handshake
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, &handshake, read_half, write_half, peer,
read_half, &config, &replay_checker, false,
write_half,
peer,
&config,
&replay_checker,
false,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient => { HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad(); stats.increment_connects_bad();
handle_bad_client(reader, writer, &handshake, &config).await;
return Ok(()); return Ok(());
} }
HandshakeResult::Error(e) => return Err(e), HandshakeResult::Error(e) => return Err(e),
}; };
Self::handle_authenticated_static( Self::handle_authenticated_static(
crypto_reader, crypto_reader, crypto_writer, success,
crypto_writer, self.upstream_manager, self.stats, self.config,
success, buffer_pool, self.rng,
self.upstream_manager,
self.stats,
self.config
).await ).await
} }
/// Static version of handle_authenticated_inner to avoid ownership issues
async fn handle_authenticated_static<R, W>( async fn handle_authenticated_static<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@@ -278,6 +221,8 @@ impl RunningClientHandler {
upstream_manager: Arc<UpstreamManager>, upstream_manager: Arc<UpstreamManager>,
stats: Arc<Stats>, stats: Arc<Stats>,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@@ -285,13 +230,11 @@ impl RunningClientHandler {
{ {
let user = &success.user; let user = &success.user;
// Check user limits
if let Err(e) = Self::check_user_limits_static(user, &config, &stats) { if let Err(e) = Self::check_user_limits_static(user, &config, &stats) {
warn!(user = %user, error = %e, "User limit exceeded"); warn!(user = %user, error = %e, "User limit exceeded");
return Err(e); return Err(e);
} }
// Get datacenter address
let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?; let dc_addr = Self::get_dc_addr_static(success.dc_idx, &config)?;
info!( info!(
@@ -300,70 +243,54 @@ impl RunningClientHandler {
dc = success.dc_idx, dc = success.dc_idx,
dc_addr = %dc_addr, dc_addr = %dc_addr,
proto = ?success.proto_tag, proto = ?success.proto_tag,
fast_mode = config.fast_mode,
"Connecting to Telegram" "Connecting to Telegram"
); );
// Connect to Telegram via UpstreamManager // Pass dc_idx for latency-based upstream selection
let tg_stream = upstream_manager.connect(dc_addr).await?; let tg_stream = upstream_manager.connect(dc_addr, Some(success.dc_idx)).await?;
debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected to Telegram, performing handshake"); debug!(peer = %success.peer, dc_addr = %dc_addr, "Connected, performing TG handshake");
// Perform Telegram handshake and get crypto streams
let (tg_reader, tg_writer) = Self::do_tg_handshake_static( let (tg_reader, tg_writer) = Self::do_tg_handshake_static(
tg_stream, tg_stream, &success, &config, rng.as_ref(),
&success,
&config,
).await?; ).await?;
debug!(peer = %success.peer, "Telegram handshake complete, starting relay"); debug!(peer = %success.peer, "TG handshake complete, starting relay");
// Update stats
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user); stats.increment_user_curr_connects(user);
// Relay traffic
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
client_reader, client_reader, client_writer,
client_writer, tg_reader, tg_writer,
tg_reader, user, Arc::clone(&stats), buffer_pool,
tg_writer,
user,
Arc::clone(&stats),
).await; ).await;
// Update stats
stats.decrement_user_curr_connects(user); stats.decrement_user_curr_connects(user);
match &relay_result { match &relay_result {
Ok(()) => debug!(user = %user, peer = %success.peer, "Relay completed normally"), Ok(()) => debug!(user = %user, "Relay completed"),
Err(e) => debug!(user = %user, peer = %success.peer, error = %e, "Relay ended with error"), Err(e) => debug!(user = %user, error = %e, "Relay ended with error"),
} }
relay_result relay_result
} }
/// Check user limits (static version)
fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> { fn check_user_limits_static(user: &str, config: &ProxyConfig, stats: &Stats) -> Result<()> {
// Check expiration if let Some(expiration) = config.access.user_expirations.get(user) {
if let Some(expiration) = config.user_expirations.get(user) {
if chrono::Utc::now() > *expiration { if chrono::Utc::now() > *expiration {
return Err(ProxyError::UserExpired { user: user.to_string() }); return Err(ProxyError::UserExpired { user: user.to_string() });
} }
} }
// Check connection limit if let Some(limit) = config.access.user_max_tcp_conns.get(user) {
if let Some(limit) = config.user_max_tcp_conns.get(user) { if stats.get_user_curr_connects(user) >= *limit as u64 {
let current = stats.get_user_curr_connects(user);
if current >= *limit as u64 {
return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() }); return Err(ProxyError::ConnectionLimitExceeded { user: user.to_string() });
} }
} }
// Check data quota if let Some(quota) = config.access.user_data_quota.get(user) {
if let Some(quota) = config.user_data_quota.get(user) { if stats.get_user_total_octets(user) >= *quota {
let used = stats.get_user_total_octets(user);
if used >= *quota {
return Err(ProxyError::DataQuotaExceeded { user: user.to_string() }); return Err(ProxyError::DataQuotaExceeded { user: user.to_string() });
} }
} }
@@ -371,62 +298,105 @@ impl RunningClientHandler {
Ok(()) Ok(())
} }
/// Get datacenter address by index (static version) /// Resolve DC index to a target address.
///
/// Matches the C implementation's behavior exactly:
///
/// 1. Look up DC in known clusters (standard DCs ±1..±5)
/// 2. If not found and `force=1` → fall back to `default_cluster`
///
/// In the C code:
/// - `proxy-multi.conf` is downloaded from Telegram, contains only DC ±1..±5
/// - `default 2;` directive sets the default cluster
/// - `mf_cluster_lookup(CurConf, target_dc, 1)` returns default_cluster
/// for any unknown DC (like CDN DC 203)
///
/// So DC 203, DC 101, DC -300, etc. all route to the default DC (2).
/// There is NO modular arithmetic in the C implementation.
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> { fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let idx = (dc_idx.abs() - 1) as usize; let datacenters = if config.general.prefer_ipv6 {
let datacenters = if config.prefer_ipv6 {
&*TG_DATACENTERS_V6 &*TG_DATACENTERS_V6
} else { } else {
&*TG_DATACENTERS_V4 &*TG_DATACENTERS_V4
}; };
datacenters.get(idx) let num_dcs = datacenters.len(); // 5
.map(|ip| SocketAddr::new(*ip, TG_DATACENTER_PORT))
.ok_or_else(|| ProxyError::InvalidHandshake( // === Step 1: Check dc_overrides (like C's `proxy_for <dc> <ip>:<port>`) ===
format!("Invalid DC index: {}", dc_idx) let dc_key = dc_idx.to_string();
)) if let Some(addr_str) = config.dc_overrides.get(&dc_key) {
match addr_str.parse::<SocketAddr>() {
Ok(addr) => {
debug!(dc_idx = dc_idx, addr = %addr, "Using DC override from config");
return Ok(addr);
}
Err(_) => {
warn!(dc_idx = dc_idx, addr_str = %addr_str,
"Invalid DC override address in config, ignoring");
}
}
}
// === Step 2: Standard DCs ±1..±5 — direct lookup ===
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc >= 1 && abs_dc <= num_dcs {
return Ok(SocketAddr::new(datacenters[abs_dc - 1], TG_DATACENTER_PORT));
}
// === Step 3: Unknown DC — fall back to default_cluster ===
// Exactly like C's `mf_cluster_lookup(CurConf, target_dc, force=1)`
// which returns `MC->default_cluster` when the DC is not found.
// Telegram's proxy-multi.conf uses `default 2;`
let default_dc = config.default_dc.unwrap_or(2) as usize;
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1
} else {
1 // DC 2 (index 1) — matches Telegram's `default 2;`
};
info!(
original_dc = dc_idx,
fallback_dc = (fallback_idx + 1) as u16,
fallback_addr = %datacenters[fallback_idx],
"Special DC ---> default_cluster"
);
Ok(SocketAddr::new(datacenters[fallback_idx], TG_DATACENTER_PORT))
} }
/// Perform handshake with Telegram server (static version)
async fn do_tg_handshake_static( async fn do_tg_handshake_static(
mut stream: TcpStream, mut stream: TcpStream,
success: &HandshakeSuccess, success: &HandshakeSuccess,
config: &ProxyConfig, config: &ProxyConfig,
rng: &SecureRandom,
) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> { ) -> Result<(CryptoReader<tokio::net::tcp::OwnedReadHalf>, CryptoWriter<tokio::net::tcp::OwnedWriteHalf>)> {
// Generate nonce with keys for TG
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
success.proto_tag, success.proto_tag,
&success.dec_key, // Client's dec key &success.dec_key,
success.dec_iv, success.dec_iv,
config.fast_mode, rng,
config.general.fast_mode,
); );
// Encrypt nonce
let encrypted_nonce = encrypt_tg_nonce(&nonce); let encrypted_nonce = encrypt_tg_nonce(&nonce);
debug!( debug!(
peer = %success.peer, peer = %success.peer,
nonce_head = %hex::encode(&nonce[..16]), nonce_head = %hex::encode(&nonce[..16]),
encrypted_head = %hex::encode(&encrypted_nonce[..16]),
"Sending nonce to Telegram" "Sending nonce to Telegram"
); );
// Send to Telegram
stream.write_all(&encrypted_nonce).await?; stream.write_all(&encrypted_nonce).await?;
stream.flush().await?; stream.flush().await?;
debug!(peer = %success.peer, "Nonce sent to Telegram");
// Split stream and wrap with crypto
let (read_half, write_half) = stream.into_split(); let (read_half, write_half) = stream.into_split();
let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv); let decryptor = AesCtr::new(&tg_dec_key, tg_dec_iv);
let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv); let encryptor = AesCtr::new(&tg_enc_key, tg_enc_iv);
let tg_reader = CryptoReader::new(read_half, decryptor); Ok((
let tg_writer = CryptoWriter::new(write_half, encryptor); CryptoReader::new(read_half, decryptor),
CryptoWriter::new(write_half, encryptor),
Ok((tg_reader, tg_writer)) ))
} }
} }

View File

@@ -1,11 +1,11 @@
//! MTProto Handshake Magics //! MTProto Handshake
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info}; use tracing::{debug, warn, trace, info};
use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr}; use crate::crypto::{sha256, AesCtr, SecureRandom};
use crate::crypto::random::SECURE_RANDOM;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::protocol::tls; use crate::protocol::tls;
use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter}; use crate::stream::{FakeTlsReader, FakeTlsWriter, CryptoReader, CryptoWriter};
@@ -14,6 +14,9 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
/// Result of successful handshake /// Result of successful handshake
///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
/// zeroized on drop.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct HandshakeSuccess { pub struct HandshakeSuccess {
/// Authenticated user name /// Authenticated user name
@@ -34,6 +37,15 @@ pub struct HandshakeSuccess {
pub is_tls: bool, pub is_tls: bool,
} }
impl Drop for HandshakeSuccess {
fn drop(&mut self) {
self.dec_key.zeroize();
self.dec_iv.zeroize();
self.enc_key.zeroize();
self.enc_iv.zeroize();
}
}
/// Handle fake TLS handshake /// Handle fake TLS handshake
pub async fn handle_tls_handshake<R, W>( pub async fn handle_tls_handshake<R, W>(
handshake: &[u8], handshake: &[u8],
@@ -42,76 +54,74 @@ pub async fn handle_tls_handshake<R, W>(
peer: SocketAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String)> rng: &SecureRandom,
) -> HandshakeResult<(FakeTlsReader<R>, FakeTlsWriter<W>, String), R, W>
where where
R: AsyncRead + Unpin, R: AsyncRead + Unpin,
W: AsyncWrite + Unpin, W: AsyncWrite + Unpin,
{ {
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
// Check minimum length
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Extract digest for replay check
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]; let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
// Check for replay
if replay_checker.check_tls_digest(digest_half) { if replay_checker.check_tls_digest(digest_half) {
warn!(peer = %peer, "TLS replay attack detected"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Build secrets list let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter()
let secrets: Vec<(String, Vec<u8>)> = config.users.iter()
.filter_map(|(name, hex)| { .filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
}) })
.collect(); .collect();
debug!(peer = %peer, num_users = secrets.len(), "Validating TLS handshake against users");
// Validate handshake
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
&secrets, &secrets,
config.ignore_time_skew, config.access.ignore_time_skew,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
debug!(peer = %peer, "TLS handshake validation failed - no matching user"); debug!(
return HandshakeResult::BadClient; peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
} }
}; };
// Get secret for response
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s, Some((_, s)) => s,
None => return HandshakeResult::BadClient, None => return HandshakeResult::BadClient { reader, writer },
}; };
// Build and send response
let response = tls::build_server_hello( let response = tls::build_server_hello(
secret, secret,
&validation.digest, &validation.digest,
&validation.session_id, &validation.session_id,
config.fake_cert_len, config.censorship.fake_cert_len,
rng,
); );
debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello"); debug!(peer = %peer, response_len = response.len(), "Sending TLS ServerHello");
if let Err(e) = writer.write_all(&response).await { if let Err(e) = writer.write_all(&response).await {
warn!(peer = %peer, error = %e, "Failed to write TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
if let Err(e) = writer.flush().await { if let Err(e) = writer.flush().await {
warn!(peer = %peer, error = %e, "Failed to flush TLS ServerHello");
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
// Record for replay protection
replay_checker.add_tls_digest(digest_half); replay_checker.add_tls_digest(digest_half);
info!( info!(
@@ -136,39 +146,28 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
is_tls: bool, is_tls: bool,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess)> ) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send, W: AsyncWrite + Unpin + Send,
{ {
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
// Extract prekey and IV
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
debug!(
peer = %peer,
dec_prekey_iv = %hex::encode(dec_prekey_iv),
"Extracted prekey+IV from handshake"
);
// Check for replay
if replay_checker.check_handshake(dec_prekey_iv) { if replay_checker.check_handshake(dec_prekey_iv) {
warn!(peer = %peer, "MTProto replay attack detected"); warn!(peer = %peer, "MTProto replay attack detected");
return HandshakeResult::BadClient; return HandshakeResult::BadClient { reader, writer };
} }
// Reversed for encryption direction
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
// Try each user's secret for (user, secret_hex) in &config.access.users {
for (user, secret_hex) in &config.users {
let secret = match hex::decode(secret_hex) { let secret = match hex::decode(secret_hex) {
Ok(s) => s, Ok(s) => s,
Err(_) => continue, Err(_) => continue,
}; };
// Derive decryption key
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@@ -179,38 +178,23 @@ where
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap()); let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
// Decrypt handshake to check protocol tag
let mut decryptor = AesCtr::new(&dec_key, dec_iv); let mut decryptor = AesCtr::new(&dec_key, dec_iv);
let decrypted = decryptor.decrypt(handshake); let decrypted = decryptor.decrypt(handshake);
trace!(
peer = %peer,
user = %user,
decrypted_tail = %hex::encode(&decrypted[PROTO_TAG_POS..]),
"Decrypted handshake tail"
);
// Check protocol tag
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4] let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
.try_into() .try_into()
.unwrap(); .unwrap();
let proto_tag = match ProtoTag::from_bytes(tag_bytes) { let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
Some(tag) => tag, Some(tag) => tag,
None => { None => continue,
trace!(peer = %peer, user = %user, tag = %hex::encode(tag_bytes), "Invalid proto tag");
continue;
}
}; };
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Found valid proto tag");
// Check if mode is enabled
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.modes.tls } else { config.modes.secure } if is_tls { config.general.modes.tls } else { config.general.modes.secure }
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };
if !mode_ok { if !mode_ok {
@@ -218,12 +202,10 @@ where
continue; continue;
} }
// Extract DC index
let dc_idx = i16::from_le_bytes( let dc_idx = i16::from_le_bytes(
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap() decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
); );
// Derive encryption key
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN]; let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..]; let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
@@ -234,10 +216,8 @@ where
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap()); let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
// Record for replay protection
replay_checker.add_handshake(dec_prekey_iv); replay_checker.add_handshake(dec_prekey_iv);
// Create new cipher instances
let decryptor = AesCtr::new(&dec_key, dec_iv); let decryptor = AesCtr::new(&dec_key, dec_iv);
let encryptor = AesCtr::new(&enc_key, enc_iv); let encryptor = AesCtr::new(&enc_key, enc_iv);
@@ -270,56 +250,37 @@ where
} }
debug!(peer = %peer, "MTProto handshake: no matching user found"); debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient HandshakeResult::BadClient { reader, writer }
} }
/// Generate nonce for Telegram connection /// Generate nonce for Telegram connection
///
/// In FAST MODE: we use the same keys for TG as for client, but reversed.
/// This means: client's enc_key becomes TG's dec_key and vice versa.
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
client_dec_key: &[u8; 32], client_dec_key: &[u8; 32],
client_dec_iv: u128, client_dec_iv: u128,
rng: &SecureRandom,
fast_mode: bool, fast_mode: bool,
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) { ) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
loop { loop {
let bytes = SECURE_RANDOM.bytes(HANDSHAKE_LEN); let bytes = rng.bytes(HANDSHAKE_LEN);
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap(); let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
// Check reserved patterns if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) {
continue;
}
let first_four: [u8; 4] = nonce[..4].try_into().unwrap(); let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
continue;
}
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap(); let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
continue;
}
// Set protocol tag
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
// Fast mode: copy client's dec_key+iv (this becomes TG's enc direction)
// In fast mode, we make TG use the same keys as client but swapped:
// - What we decrypt FROM TG = what we encrypt TO client (so no re-encryption needed)
// - What we encrypt TO TG = what we decrypt FROM client
if fast_mode { if fast_mode {
// Put client's dec_key + dec_iv into nonce[8:56]
// This will be used by TG for encryption TO us
nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key); nonce[SKIP_LEN..SKIP_LEN + KEY_LEN].copy_from_slice(client_dec_key);
nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN] nonce[SKIP_LEN + KEY_LEN..SKIP_LEN + KEY_LEN + IV_LEN]
.copy_from_slice(&client_dec_iv.to_be_bytes()); .copy_from_slice(&client_dec_iv.to_be_bytes());
} }
// Now compute what keys WE will use for TG connection
// enc_key_iv = nonce[8:56] (for encrypting TO TG)
// dec_key_iv = nonce[8:56] reversed (for decrypting FROM TG)
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect(); let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
@@ -329,44 +290,22 @@ pub fn generate_tg_nonce(
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap(); let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap()); let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
debug!(
fast_mode = fast_mode,
tg_enc_key = %hex::encode(&tg_enc_key[..8]),
tg_dec_key = %hex::encode(&tg_dec_key[..8]),
"Generated TG nonce"
);
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv); return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
} }
} }
/// Encrypt nonce for sending to Telegram /// Encrypt nonce for sending to Telegram
///
/// Only the part from PROTO_TAG_POS onwards is encrypted.
/// The encryption key is derived from enc_key_iv in the nonce itself.
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> { pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
// enc_key_iv is at nonce[8:56]
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
// Key for encrypting is just the first 32 bytes of enc_key_iv
let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap(); let key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap()); let iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
let mut encryptor = AesCtr::new(&key, iv); let mut encryptor = AesCtr::new(&key, iv);
// Encrypt the entire nonce first, then take only the encrypted tail
let encrypted_full = encryptor.encrypt(nonce); let encrypted_full = encryptor.encrypt(nonce);
// Result: unencrypted head + encrypted tail
let mut result = nonce[..PROTO_TAG_POS].to_vec(); let mut result = nonce[..PROTO_TAG_POS].to_vec();
result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]); result.extend_from_slice(&encrypted_full[PROTO_TAG_POS..]);
trace!(
original = %hex::encode(&nonce[PROTO_TAG_POS..]),
encrypted = %hex::encode(&result[PROTO_TAG_POS..]),
"Encrypted nonce tail"
);
result result
} }
@@ -379,13 +318,12 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = let rng = SecureRandom::new();
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
// Check length
assert_eq!(nonce.len(), HANDSHAKE_LEN); assert_eq!(nonce.len(), HANDSHAKE_LEN);
// Check proto tag is set
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap(); let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure)); assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
} }
@@ -395,17 +333,35 @@ mod tests {
let client_dec_key = [0x42u8; 32]; let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128; let client_dec_iv = 12345u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = let (nonce, _, _, _, _) =
generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, false); generate_tg_nonce(ProtoTag::Secure, &client_dec_key, client_dec_iv, &rng, false);
let encrypted = encrypt_tg_nonce(&nonce); let encrypted = encrypt_tg_nonce(&nonce);
assert_eq!(encrypted.len(), HANDSHAKE_LEN); assert_eq!(encrypted.len(), HANDSHAKE_LEN);
// First PROTO_TAG_POS bytes should be unchanged
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
// Rest should be different (encrypted)
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]); assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
} }
#[test]
fn test_handshake_success_zeroize_on_drop() {
let success = HandshakeSuccess {
user: "test".to_string(),
dc_idx: 2,
proto_tag: ProtoTag::Secure,
dec_key: [0xAA; 32],
dec_iv: 0xBBBBBBBB,
enc_key: [0xCC; 32],
enc_iv: 0xDDDDDDDD,
peer: "127.0.0.1:1234".parse().unwrap(),
is_tls: true,
};
assert_eq!(success.dec_key, [0xAA; 32]);
assert_eq!(success.enc_key, [0xCC; 32]);
drop(success);
// Drop impl zeroizes key material without panic
}
} }

View File

@@ -1,35 +1,76 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::time::Duration; use std::time::Duration;
use std::str;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout; use tokio::time::timeout;
use tracing::debug; use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::transport::set_linger_zero;
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum duration for the entire masking relay.
/// Limits resource consumption from slow-loris attacks and port scanners.
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
const MASK_BUFFER_SIZE: usize = 8192; const MASK_BUFFER_SIZE: usize = 8192;
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4 {
if data.starts_with(b"GET ") || data.starts_with(b"POST") ||
data.starts_with(b"HEAD") || data.starts_with(b"PUT ") ||
data.starts_with(b"DELETE") || data.starts_with(b"OPTIONS") {
return "HTTP";
}
}
// Check for TLS ClientHello (0x16 = handshake, 0x03 0x01-0x03 = TLS version)
if data.len() > 3 && data[0] == 0x16 && data[1] == 0x03 {
return "TLS-scanner";
}
// Check for SSH
if data.starts_with(b"SSH-") {
return "SSH";
}
// Port scanner (very short data)
if data.len() < 10 {
return "port-scanner";
}
"unknown"
}
/// Handle a bad client by forwarding to mask host /// Handle a bad client by forwarding to mask host
pub async fn handle_bad_client( pub async fn handle_bad_client<R, W>(
client: TcpStream, mut reader: R,
mut writer: W,
initial_data: &[u8], initial_data: &[u8],
config: &ProxyConfig, config: &ProxyConfig,
) { )
if !config.mask { where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
if !config.censorship.mask {
// Masking disabled, just consume data // Masking disabled, just consume data
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
let mask_host = config.mask_host.as_deref() let client_type = detect_client_type(initial_data);
.unwrap_or(&config.tls_domain);
let mask_port = config.mask_port; let mask_host = config.censorship.mask_host.as_deref()
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
debug!( debug!(
client_type = client_type,
host = %mask_host, host = %mask_host,
port = mask_port, port = mask_port,
data_len = initial_data.len(),
"Forwarding bad client to mask host" "Forwarding bad client to mask host"
); );
@@ -40,33 +81,32 @@ pub async fn handle_bad_client(
TcpStream::connect(&mask_addr) TcpStream::connect(&mask_addr)
).await; ).await;
let mut mask_stream = match connect_result { let mask_stream = match connect_result {
Ok(Ok(s)) => s, Ok(Ok(s)) => s,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(client).await; consume_client_data(reader).await;
return; return;
} }
}; };
let (mut mask_read, mut mask_write) = mask_stream.into_split();
// Send initial data to mask host // Send initial data to mask host
if mask_stream.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;
} }
// Relay traffic // Relay traffic
let (mut client_read, mut client_write) = client.into_split();
let (mut mask_read, mut mask_write) = mask_stream.into_split();
let c2m = tokio::spawn(async move { let c2m = tokio::spawn(async move {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
loop { loop {
match client_read.read(&mut buf).await { match reader.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = mask_write.shutdown().await; let _ = mask_write.shutdown().await;
break; break;
@@ -85,11 +125,11 @@ pub async fn handle_bad_client(
loop { loop {
match mask_read.read(&mut buf).await { match mask_read.read(&mut buf).await {
Ok(0) | Err(_) => { Ok(0) | Err(_) => {
let _ = client_write.shutdown().await; let _ = writer.shutdown().await;
break; break;
} }
Ok(n) => { Ok(n) => {
if client_write.write_all(&buf[..n]).await.is_err() { if writer.write_all(&buf[..n]).await.is_err() {
break; break;
} }
} }
@@ -105,9 +145,9 @@ pub async fn handle_bad_client(
} }
/// Just consume all data from client without responding /// Just consume all data from client without responding
async fn consume_client_data(mut client: TcpStream) { async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = client.read(&mut buf).await { while let Ok(n) = reader.read(&mut buf).await {
if n == 0 { if n == 0 {
break; break;
} }

View File

@@ -1,30 +1,320 @@
//! Bidirectional Relay //! Bidirectional Relay — poll-based, no head-of-line blocking
//!
//! ## What changed and why
//!
//! Previous implementation used a single-task `select! { biased; ... }` loop
//! where each branch called `write_all()`. This caused head-of-line blocking:
//! while `write_all()` waited for a slow writer (e.g. client on 3G downloading
//! media), the entire loop was blocked — the other direction couldn't make progress.
//!
//! Symptoms observed in production:
//! - Media loading at ~8 KB/s despite fast server connection
//! - Stop-and-go pattern with 50500ms gaps between chunks
//! - `biased` select starving S→C direction
//! - Some users unable to load media at all
//!
//! ## New architecture
//!
//! Uses `tokio::io::copy_bidirectional` which polls both directions concurrently
//! in a single task via non-blocking `poll_read` / `poll_write` calls:
//!
//! Old (select! + write_all — BLOCKING):
//!
//! loop {
//! select! {
//! biased;
//! data = client.read() => { server.write_all(data).await; } ← BLOCKS here
//! data = server.read() => { client.write_all(data).await; } ← can't run
//! }
//! }
//!
//! New (copy_bidirectional — CONCURRENT):
//!
//! poll(cx) {
//! // Both directions polled in the same poll cycle
//! C→S: poll_read(client) → poll_write(server) // non-blocking
//! S→C: poll_read(server) → poll_write(client) // non-blocking
//! // If one writer is Pending, the other direction still progresses
//! }
//!
//! Benefits:
//! - No head-of-line blocking: slow client download doesn't block uploads
//! - No biased starvation: fair polling of both directions
//! - Proper flush: `copy_bidirectional` calls `poll_flush` when reader stalls,
//! so CryptoWriter's pending ciphertext is always drained (fixes "stuck at 95%")
//! - No deadlock risk: old write_all could deadlock when both TCP buffers filled;
//! poll-based approach lets TCP flow control work correctly
//!
//! Stats tracking:
//! - `StatsIo` wraps client side, intercepts `poll_read` / `poll_write`
//! - `poll_read` on client = C→S (client sending) → `octets_from`, `msgs_from`
//! - `poll_write` on client = S→C (to client) → `octets_to`, `msgs_to`
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use std::io;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional};
use tokio::time::Instant; use tokio::time::Instant;
use tracing::{debug, trace, warn, info}; use tracing::{debug, trace, warn};
use crate::error::Result; use crate::error::Result;
use crate::stats::Stats; use crate::stats::Stats;
use std::sync::atomic::{AtomicU64, Ordering}; use crate::stream::BufferPool;
// CHANGED: Reduced from 128KB to 16KB to match TLS record size and prevent bufferbloat. // ============= Constants =============
// This is critical for iOS clients to maintain proper TCP flow control during uploads.
const BUFFER_SIZE: usize = 16384;
// Activity timeout for iOS compatibility (30 minutes) /// Activity timeout for iOS compatibility.
// iOS does not support TCP_USER_TIMEOUT, so we implement application-level timeout ///
const ACTIVITY_TIMEOUT_SECS: u64 = 1800; /// iOS keeps Telegram connections alive in background for up to 30 minutes.
/// Closing earlier causes unnecessary reconnects and handshake overhead.
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
/// Relay data bidirectionally between client and server /// Watchdog check interval — also used for periodic rate logging.
///
/// 10 seconds gives responsive timeout detection (±10s accuracy)
/// without measurable overhead from atomic reads.
const WATCHDOG_INTERVAL: Duration = Duration::from_secs(10);
// ============= CombinedStream =============
/// Combines separate read and write halves into a single bidirectional stream.
///
/// `copy_bidirectional` requires `AsyncRead + AsyncWrite` on each side,
/// but the handshake layer produces split reader/writer pairs
/// (e.g. `CryptoReader<FakeTlsReader<OwnedReadHalf>>` + `CryptoWriter<...>`).
///
/// This wrapper reunifies them with zero overhead — each trait method
/// delegates directly to the corresponding half. No buffering, no copies.
///
/// Safety: `poll_read` only touches `reader`, `poll_write` only touches `writer`,
/// so there's no aliasing even though both are called on the same `&mut self`.
struct CombinedStream<R, W> {
reader: R,
writer: W,
}
impl<R, W> CombinedStream<R, W> {
fn new(reader: R, writer: W) -> Self {
Self { reader, writer }
}
}
impl<R: AsyncRead + Unpin, W: Unpin> AsyncRead for CombinedStream<R, W> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().reader).poll_read(cx, buf)
}
}
impl<R: Unpin, W: AsyncWrite + Unpin> AsyncWrite for CombinedStream<R, W> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.get_mut().writer).poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().writer).poll_shutdown(cx)
}
}
// ============= SharedCounters =============
/// Atomic counters shared between the relay (via StatsIo) and the watchdog task.
///
/// Using `Relaxed` ordering is sufficient because:
/// - Counters are monotonically increasing (no ABA problem)
/// - Slight staleness in watchdog reads is harmless (±10s check interval anyway)
/// - No ordering dependencies between different counters
struct SharedCounters {
/// Bytes read from client (C→S direction)
c2s_bytes: AtomicU64,
/// Bytes written to client (S→C direction)
s2c_bytes: AtomicU64,
/// Number of poll_read completions (≈ C→S chunks)
c2s_ops: AtomicU64,
/// Number of poll_write completions (≈ S→C chunks)
s2c_ops: AtomicU64,
/// Milliseconds since relay epoch of last I/O activity
last_activity_ms: AtomicU64,
}
impl SharedCounters {
fn new() -> Self {
Self {
c2s_bytes: AtomicU64::new(0),
s2c_bytes: AtomicU64::new(0),
c2s_ops: AtomicU64::new(0),
s2c_ops: AtomicU64::new(0),
last_activity_ms: AtomicU64::new(0),
}
}
/// Record activity at this instant.
#[inline]
fn touch(&self, now: Instant, epoch: Instant) {
let ms = now.duration_since(epoch).as_millis() as u64;
self.last_activity_ms.store(ms, Ordering::Relaxed);
}
/// How long since last recorded activity.
fn idle_duration(&self, now: Instant, epoch: Instant) -> Duration {
let last_ms = self.last_activity_ms.load(Ordering::Relaxed);
let now_ms = now.duration_since(epoch).as_millis() as u64;
Duration::from_millis(now_ms.saturating_sub(last_ms))
}
}
// ============= StatsIo =============
/// Transparent I/O wrapper that tracks per-user statistics and activity.
///
/// Wraps the **client** side of the relay. Direction mapping:
///
/// | poll method | direction | stats updated |
/// |-------------|-----------|--------------------------------------|
/// | `poll_read` | C→S | `octets_from`, `msgs_from`, counters |
/// | `poll_write` | S→C | `octets_to`, `msgs_to`, counters |
///
/// Both update the shared activity timestamp for the watchdog.
///
/// Note on message counts: the original code counted one `read()`/`write_all()`
/// as one "message". Here we count `poll_read`/`poll_write` completions instead.
/// Byte counts are identical; op counts may differ slightly due to different
/// internal buffering in `copy_bidirectional`. This is fine for monitoring.
struct StatsIo<S> {
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
}
impl<S> StatsIo<S> {
fn new(
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
}
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
// C→S: client sent data
this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
}
other => other,
}
}
}
impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
this.counters.s2c_bytes.fetch_add(n as u64, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
}
other => other,
}
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
// ============= Relay =============
/// Relay data bidirectionally between client and server.
///
/// Uses `tokio::io::copy_bidirectional` for concurrent, non-blocking data transfer.
///
/// ## API compatibility
///
/// Signature is identical to the previous implementation. The `_buffer_pool`
/// parameter is retained for call-site compatibility — `copy_bidirectional`
/// manages its own internal buffers (8 KB per direction).
///
/// ## Guarantees preserved
///
/// - Activity timeout: 30 minutes of inactivity → clean shutdown
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>( pub async fn relay_bidirectional<CR, CW, SR, SW>(
mut client_reader: CR, client_reader: CR,
mut client_writer: CW, client_writer: CW,
mut server_reader: SR, server_reader: SR,
mut server_writer: SW, server_writer: SW,
user: &str, user: &str,
stats: Arc<Stats>, stats: Arc<Stats>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()> ) -> Result<()>
where where
CR: AsyncRead + Unpin + Send + 'static, CR: AsyncRead + Unpin + Send + 'static,
@@ -32,218 +322,145 @@ where
SR: AsyncRead + Unpin + Send + 'static, SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static, SW: AsyncWrite + Unpin + Send + 'static,
{ {
let user_c2s = user.to_string(); let epoch = Instant::now();
let user_s2c = user.to_string(); let counters = Arc::new(SharedCounters::new());
let user_owned = user.to_string();
// Используем Arc::clone вместо stats.clone() // ── Combine split halves into bidirectional streams ──────────────
let stats_c2s = Arc::clone(&stats); let client_combined = CombinedStream::new(client_reader, client_writer);
let stats_s2c = Arc::clone(&stats); let mut server = CombinedStream::new(server_reader, server_writer);
let c2s_bytes = Arc::new(AtomicU64::new(0)); // Wrap client with stats/activity tracking
let s2c_bytes = Arc::new(AtomicU64::new(0)); let mut client = StatsIo::new(
let c2s_bytes_clone = Arc::clone(&c2s_bytes); client_combined,
let s2c_bytes_clone = Arc::clone(&s2c_bytes); Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
epoch,
);
// Activity timeout for iOS compatibility // ── Watchdog: activity timeout + periodic rate logging ──────────
let activity_timeout = Duration::from_secs(ACTIVITY_TIMEOUT_SECS); let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
// Client -> Server task with activity timeout let watchdog = async {
let c2s = tokio::spawn(async move { let mut prev_c2s: u64 = 0;
let mut buf = vec![0u8; BUFFER_SIZE]; let mut prev_s2c: u64 = 0;
let mut total_bytes = 0u64;
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { loop {
// Read with timeout to prevent infinite hang on iOS tokio::time::sleep(WATCHDOG_INTERVAL).await;
let read_result = tokio::time::timeout(
activity_timeout,
client_reader.read(&mut buf)
).await;
match read_result { let now = Instant::now();
// Timeout - no activity for too long let idle = wd_counters.idle_duration(now, epoch);
Err(_) => {
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
warn!( warn!(
user = %user_c2s, user = %wd_user,
total_bytes = total_bytes, c2s_bytes = c2s,
msgs = msg_count, s2c_bytes = s2c,
idle_secs = last_activity.elapsed().as_secs(), idle_secs = idle.as_secs(),
"Activity timeout (C->S) - no data received" "Activity timeout"
); );
let _ = server_writer.shutdown().await; return; // Causes select! to cancel copy_bidirectional
break;
} }
// Read successful // ── Periodic rate logging ───────────────────────────────
Ok(Ok(0)) => { let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
let c2s_delta = c2s - prev_c2s;
let s2c_delta = s2c - prev_s2c;
if c2s_delta > 0 || s2c_delta > 0 {
let secs = WATCHDOG_INTERVAL.as_secs_f64();
debug!( debug!(
user = %user_c2s, user = %wd_user,
total_bytes = total_bytes, c2s_kbps = (c2s_delta as f64 / secs / 1024.0) as u64,
msgs = msg_count, s2c_kbps = (s2c_delta as f64 / secs / 1024.0) as u64,
"Client closed connection (C->S)" c2s_total = c2s,
s2c_total = s2c,
"Relay active"
); );
let _ = server_writer.shutdown().await;
break;
} }
Ok(Ok(n)) => { prev_c2s = c2s;
total_bytes += n as u64; prev_s2c = s2c;
msg_count += 1;
last_activity = Instant::now();
c2s_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_c2s.add_user_octets_from(&user_c2s, n as u64);
stats_c2s.increment_user_msgs_from(&user_c2s);
trace!(
user = %user_c2s,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"C->S data"
);
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
user = %user_c2s,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"C->S transfer in progress"
);
last_log = Instant::now();
} }
};
if let Err(e) = server_writer.write_all(&buf[..n]).await { // ── Run bidirectional copy + watchdog concurrently ───────────────
debug!(user = %user_c2s, error = %e, "Failed to write to server"); //
break; // copy_bidirectional polls both directions in the same poll() call:
} // C→S: poll_read(client/StatsIo) → poll_write(server)
if let Err(e) = server_writer.flush().await { // S→C: poll_read(server) → poll_write(client/StatsIo)
debug!(user = %user_c2s, error = %e, "Failed to flush to server"); //
break; // When one direction's writer returns Pending, the other direction
} // continues — no head-of-line blocking.
} //
// When the watchdog fires, select! drops the copy future,
// releasing the &mut borrows on client and server.
let copy_result = tokio::select! {
result = copy_bidirectional(&mut client, &mut server) => Some(result),
_ = watchdog => None, // Activity timeout — cancel relay
};
Ok(Err(e)) => { // ── Clean shutdown ──────────────────────────────────────────────
debug!(user = %user_c2s, error = %e, total_bytes = total_bytes, "Client read error"); // After select!, the losing future is dropped, borrows released.
break; // Shut down both write sides for clean TCP FIN.
} let _ = client.shutdown().await;
} let _ = server.shutdown().await;
}
});
// Server -> Client task with activity timeout // ── Final logging ───────────────────────────────────────────────
let s2c = tokio::spawn(async move { let c2s_ops = counters.c2s_ops.load(Ordering::Relaxed);
let mut buf = vec![0u8; BUFFER_SIZE]; let s2c_ops = counters.s2c_ops.load(Ordering::Relaxed);
let mut total_bytes = 0u64; let duration = epoch.elapsed();
let mut msg_count = 0u64;
let mut last_activity = Instant::now();
let mut last_log = Instant::now();
loop { match copy_result {
// Read with timeout to prevent infinite hang on iOS Some(Ok((c2s, s2c))) => {
let read_result = tokio::time::timeout( // Normal completion — one side closed the connection
activity_timeout,
server_reader.read(&mut buf)
).await;
match read_result {
// Timeout - no activity for too long
Err(_) => {
warn!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
idle_secs = last_activity.elapsed().as_secs(),
"Activity timeout (S->C) - no data received"
);
let _ = client_writer.shutdown().await;
break;
}
// Read successful
Ok(Ok(0)) => {
debug!( debug!(
user = %user_s2c, user = %user_owned,
total_bytes = total_bytes, c2s_bytes = c2s,
msgs = msg_count, s2c_bytes = s2c,
"Server closed connection (S->C)" c2s_msgs = c2s_ops,
); s2c_msgs = s2c_ops,
let _ = client_writer.shutdown().await; duration_secs = duration.as_secs(),
break;
}
Ok(Ok(n)) => {
total_bytes += n as u64;
msg_count += 1;
last_activity = Instant::now();
s2c_bytes_clone.store(total_bytes, Ordering::Relaxed);
stats_s2c.add_user_octets_to(&user_s2c, n as u64);
stats_s2c.increment_user_msgs_to(&user_s2c);
trace!(
user = %user_s2c,
bytes = n,
total = total_bytes,
data_preview = %hex::encode(&buf[..n.min(32)]),
"S->C data"
);
// Log activity every 10 seconds for large transfers
if last_log.elapsed() > Duration::from_secs(10) {
let rate = total_bytes as f64 / last_log.elapsed().as_secs_f64();
info!(
user = %user_s2c,
total_bytes = total_bytes,
msgs = msg_count,
rate_kbps = (rate / 1024.0) as u64,
"S->C transfer in progress"
);
last_log = Instant::now();
}
if let Err(e) = client_writer.write_all(&buf[..n]).await {
debug!(user = %user_s2c, error = %e, "Failed to write to client");
break;
}
if let Err(e) = client_writer.flush().await {
debug!(user = %user_s2c, error = %e, "Failed to flush to client");
break;
}
}
Ok(Err(e)) => {
debug!(user = %user_s2c, error = %e, total_bytes = total_bytes, "Server read error");
break;
}
}
}
});
// Wait for either direction to complete
tokio::select! {
result = c2s => {
if let Err(e) = result {
warn!(error = %e, "C->S task panicked");
}
}
result = s2c => {
if let Err(e) = result {
warn!(error = %e, "S->C task panicked");
}
}
}
debug!(
c2s_bytes = c2s_bytes.load(Ordering::Relaxed),
s2c_bytes = s2c_bytes.load(Ordering::Relaxed),
"Relay finished" "Relay finished"
); );
Ok(()) Ok(())
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
error = %e,
"Relay error"
);
Err(e.into())
}
None => {
// Activity timeout (watchdog fired)
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
debug!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Relay finished (activity timeout)"
);
Ok(())
}
}
} }

View File

@@ -1,29 +1,28 @@
//! Statistics //! Statistics and replay protection
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::{Instant, Duration};
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::RwLock; use parking_lot::Mutex;
use lru::LruCache; use lru::LruCache;
use std::num::NonZeroUsize; use std::num::NonZeroUsize;
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use std::collections::VecDeque;
use tracing::debug;
// ============= Stats =============
/// Thread-safe statistics
#[derive(Default)] #[derive(Default)]
pub struct Stats { pub struct Stats {
// Global counters
connects_all: AtomicU64, connects_all: AtomicU64,
connects_bad: AtomicU64, connects_bad: AtomicU64,
handshake_timeouts: AtomicU64, handshake_timeouts: AtomicU64,
// Per-user stats
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>,
// Start time
start_time: RwLock<Option<Instant>>,
} }
/// Per-user statistics
#[derive(Default)] #[derive(Default)]
pub struct UserStats { pub struct UserStats {
pub connects: AtomicU64, pub connects: AtomicU64,
@@ -41,42 +40,20 @@ impl Stats {
stats stats
} }
// Global stats pub fn increment_connects_all(&self) { self.connects_all.fetch_add(1, Ordering::Relaxed); }
pub fn increment_connects_all(&self) { pub fn increment_connects_bad(&self) { self.connects_bad.fetch_add(1, Ordering::Relaxed); }
self.connects_all.fetch_add(1, Ordering::Relaxed); pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
} pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn increment_connects_bad(&self) {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_handshake_timeouts(&self) {
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 {
self.connects_all.load(Ordering::Relaxed)
}
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
// User stats
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_curr_connects(&self, user: &str) { pub fn increment_user_curr_connects(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .curr_connects.fetch_add(1, Ordering::Relaxed);
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -86,47 +63,33 @@ impl Stats {
} }
pub fn get_user_curr_connects(&self, user: &str) -> u64 { pub fn get_user_curr_connects(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| s.curr_connects.load(Ordering::Relaxed)) .map(|s| s.curr_connects.load(Ordering::Relaxed))
.unwrap_or(0) .unwrap_or(0)
} }
pub fn add_user_octets_from(&self, user: &str, bytes: u64) { pub fn add_user_octets_from(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_from_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn add_user_octets_to(&self, user: &str, bytes: u64) { pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .octets_to_client.fetch_add(bytes, Ordering::Relaxed);
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn increment_user_msgs_from(&self, user: &str) { pub fn increment_user_msgs_from(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_from_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_msgs_to(&self, user: &str) { pub fn increment_user_msgs_to(&self, user: &str) {
self.user_stats self.user_stats.entry(user.to_string()).or_default()
.entry(user.to_string()) .msgs_to_client.fetch_add(1, Ordering::Relaxed);
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn get_user_total_octets(&self, user: &str) -> u64 { pub fn get_user_total_octets(&self, user: &str) -> u64 {
self.user_stats self.user_stats.get(user)
.get(user)
.map(|s| { .map(|s| {
s.octets_from_client.load(Ordering::Relaxed) + s.octets_from_client.load(Ordering::Relaxed) +
s.octets_to_client.load(Ordering::Relaxed) s.octets_to_client.load(Ordering::Relaxed)
@@ -141,37 +104,209 @@ impl Stats {
} }
} }
// Arc<Stats> Hightech Stats :D // ============= Replay Checker =============
/// Replay attack checker using LRU cache
pub struct ReplayChecker { pub struct ReplayChecker {
handshakes: RwLock<LruCache<Vec<u8>, ()>>, shards: Vec<Mutex<ReplayShard>>,
tls_digests: RwLock<LruCache<Vec<u8>, ()>>, shard_mask: usize,
window: Duration,
checks: AtomicU64,
hits: AtomicU64,
additions: AtomicU64,
cleanups: AtomicU64,
}
struct ReplayEntry {
seen_at: Instant,
seq: u64,
}
struct ReplayShard {
cache: LruCache<Box<[u8]>, ReplayEntry>,
queue: VecDeque<(Instant, Box<[u8]>, u64)>,
seq_counter: u64,
}
impl ReplayShard {
fn new(cap: NonZeroUsize) -> Self {
Self {
cache: LruCache::new(cap),
queue: VecDeque::with_capacity(cap.get()),
seq_counter: 0,
}
}
fn next_seq(&mut self) -> u64 {
self.seq_counter += 1;
self.seq_counter
}
fn cleanup(&mut self, now: Instant, window: Duration) {
if window.is_zero() {
return;
}
let cutoff = now.checked_sub(window).unwrap_or(now);
while let Some((ts, _, _)) = self.queue.front() {
if *ts >= cutoff {
break;
}
let (_, key, queue_seq) = self.queue.pop_front().unwrap();
// Use key.as_ref() to get &[u8] — avoids Borrow<Q> ambiguity
// between Borrow<[u8]> and Borrow<Box<[u8]>>
if let Some(entry) = self.cache.peek(key.as_ref()) {
if entry.seq == queue_seq {
self.cache.pop(key.as_ref());
}
}
}
}
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
self.cleanup(now, window);
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
self.cache.get(key).is_some()
}
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
self.cleanup(now, window);
let seq = self.next_seq();
let boxed_key: Box<[u8]> = key.into();
self.cache.put(boxed_key.clone(), ReplayEntry { seen_at: now, seq });
self.queue.push_back((now, boxed_key, seq));
}
fn len(&self) -> usize {
self.cache.len()
}
} }
impl ReplayChecker { impl ReplayChecker {
pub fn new(capacity: usize) -> Self { pub fn new(total_capacity: usize, window: Duration) -> Self {
let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); let num_shards = 64;
let shard_capacity = (total_capacity / num_shards).max(1);
let cap = NonZeroUsize::new(shard_capacity).unwrap();
let mut shards = Vec::with_capacity(num_shards);
for _ in 0..num_shards {
shards.push(Mutex::new(ReplayShard::new(cap)));
}
Self { Self {
handshakes: RwLock::new(LruCache::new(cap)), shards,
tls_digests: RwLock::new(LruCache::new(cap)), shard_mask: num_shards - 1,
window,
checks: AtomicU64::new(0),
hits: AtomicU64::new(0),
additions: AtomicU64::new(0),
cleanups: AtomicU64::new(0),
} }
} }
pub fn check_handshake(&self, data: &[u8]) -> bool { fn get_shard_idx(&self, key: &[u8]) -> usize {
self.handshakes.read().contains(&data.to_vec()) let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.shard_mask
} }
pub fn add_handshake(&self, data: &[u8]) { fn check(&self, data: &[u8]) -> bool {
self.handshakes.write().put(data.to_vec(), ()); self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
let found = shard.check(data, Instant::now(), self.window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
} }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { fn add(&self, data: &[u8]) {
self.tls_digests.read().contains(&data.to_vec()) self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = self.shards[idx].lock();
shard.add(data, Instant::now(), self.window);
} }
pub fn add_tls_digest(&self, data: &[u8]) { pub fn check_handshake(&self, data: &[u8]) -> bool { self.check(data) }
self.tls_digests.write().put(data.to_vec(), ()); pub fn add_handshake(&self, data: &[u8]) { self.add(data) }
pub fn check_tls_digest(&self, data: &[u8]) -> bool { self.check(data) }
pub fn add_tls_digest(&self, data: &[u8]) { self.add(data) }
pub fn stats(&self) -> ReplayStats {
let mut total_entries = 0;
let mut total_queue_len = 0;
for shard in &self.shards {
let s = shard.lock();
total_entries += s.cache.len();
total_queue_len += s.queue.len();
}
ReplayStats {
total_entries,
total_queue_len,
total_checks: self.checks.load(Ordering::Relaxed),
total_hits: self.hits.load(Ordering::Relaxed),
total_additions: self.additions.load(Ordering::Relaxed),
total_cleanups: self.cleanups.load(Ordering::Relaxed),
num_shards: self.shards.len(),
window_secs: self.window.as_secs(),
}
}
pub async fn run_periodic_cleanup(&self) {
let interval = if self.window.as_secs() > 60 {
Duration::from_secs(30)
} else {
Duration::from_secs(self.window.as_secs().max(1) / 2)
};
loop {
tokio::time::sleep(interval).await;
let now = Instant::now();
let mut cleaned = 0usize;
for shard_mutex in &self.shards {
let mut shard = shard_mutex.lock();
let before = shard.len();
shard.cleanup(now, self.window);
let after = shard.len();
cleaned += before.saturating_sub(after);
}
self.cleanups.fetch_add(1, Ordering::Relaxed);
if cleaned > 0 {
debug!(cleaned = cleaned, "Replay checker: periodic cleanup");
}
}
}
}
#[derive(Debug, Clone)]
pub struct ReplayStats {
pub total_entries: usize,
pub total_queue_len: usize,
pub total_checks: u64,
pub total_hits: u64,
pub total_additions: u64,
pub total_cleanups: u64,
pub num_shards: usize,
pub window_secs: u64,
}
impl ReplayStats {
pub fn hit_rate(&self) -> f64 {
if self.total_checks == 0 { 0.0 }
else { (self.total_hits as f64 / self.total_checks as f64) * 100.0 }
}
pub fn ghost_ratio(&self) -> f64 {
if self.total_entries == 0 { 0.0 }
else { self.total_queue_len as f64 / self.total_entries as f64 }
} }
} }
@@ -182,42 +317,60 @@ mod tests {
#[test] #[test]
fn test_stats_shared_counters() { fn test_stats_shared_counters() {
let stats = Arc::new(Stats::new()); let stats = Arc::new(Stats::new());
stats.increment_connects_all();
// Симулируем использование из разных "задач" stats.increment_connects_all();
let stats1 = Arc::clone(&stats); stats.increment_connects_all();
let stats2 = Arc::clone(&stats);
stats1.increment_connects_all();
stats2.increment_connects_all();
stats1.increment_connects_all();
// Все инкременты должны быть видны
assert_eq!(stats.get_connects_all(), 3); assert_eq!(stats.get_connects_all(), 3);
} }
#[test] #[test]
fn test_user_stats_shared() { fn test_replay_checker_basic() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100, Duration::from_secs(60));
assert!(!checker.check_handshake(b"test1"));
let stats1 = Arc::clone(&stats); checker.add_handshake(b"test1");
let stats2 = Arc::clone(&stats); assert!(checker.check_handshake(b"test1"));
assert!(!checker.check_handshake(b"test2"));
stats1.add_user_octets_from("user1", 100);
stats2.add_user_octets_from("user1", 200);
stats1.add_user_octets_to("user1", 50);
assert_eq!(stats.get_user_total_octets("user1"), 350);
} }
#[test] #[test]
fn test_concurrent_user_connects() { fn test_replay_checker_duplicate_add() {
let stats = Arc::new(Stats::new()); let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"dup");
checker.add_handshake(b"dup");
assert!(checker.check_handshake(b"dup"));
}
stats.increment_user_curr_connects("user1"); #[test]
stats.increment_user_curr_connects("user1"); fn test_replay_checker_expiration() {
assert_eq!(stats.get_user_curr_connects("user1"), 2); let checker = ReplayChecker::new(100, Duration::from_millis(50));
checker.add_handshake(b"expire");
assert!(checker.check_handshake(b"expire"));
std::thread::sleep(Duration::from_millis(100));
assert!(!checker.check_handshake(b"expire"));
}
stats.decrement_user_curr_connects("user1"); #[test]
assert_eq!(stats.get_user_curr_connects("user1"), 1); fn test_replay_checker_stats() {
let checker = ReplayChecker::new(100, Duration::from_secs(60));
checker.add_handshake(b"k1");
checker.add_handshake(b"k2");
checker.check_handshake(b"k1");
checker.check_handshake(b"k3");
let stats = checker.stats();
assert_eq!(stats.total_additions, 2);
assert_eq!(stats.total_checks, 2);
assert_eq!(stats.total_hits, 1);
}
#[test]
fn test_replay_checker_many_keys() {
let checker = ReplayChecker::new(1000, Duration::from_secs(60));
for i in 0..500u32 {
checker.add(&i.to_le_bytes());
}
for i in 0..500u32 {
assert!(checker.check(&i.to_le_bytes()));
}
assert_eq!(checker.stats().total_entries, 500);
} }
} }

View File

@@ -45,6 +45,11 @@
//! - when upstream is Pending but pending still has room: accept `to_accept` bytes and //! - when upstream is Pending but pending still has room: accept `to_accept` bytes and
//! encrypt+append ciphertext directly into pending (in-place encryption of appended range) //! encrypt+append ciphertext directly into pending (in-place encryption of appended range)
//! Encrypted stream wrappers using AES-CTR
//!
//! This module provides stateful async stream wrappers that handle
//! encryption/decryption with proper partial read/write handling.
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::io::{self, ErrorKind, Result}; use std::io::{self, ErrorKind, Result};
use std::pin::Pin; use std::pin::Pin;
@@ -58,8 +63,9 @@ use super::state::{StreamState, YieldBuffer};
// ============= Constants ============= // ============= Constants =============
/// Maximum size for pending ciphertext buffer (bounded backpressure). /// Maximum size for pending ciphertext buffer (bounded backpressure).
/// 512 KiB tends to work well for mobile networks and avoids huge latency spikes. /// Reduced to 64KB to prevent bufferbloat on mobile networks.
const MAX_PENDING_WRITE: usize = 524_288; /// 512KB was causing high latency on 3G/LTE connections.
const MAX_PENDING_WRITE: usize = 64 * 1024;
/// Default read buffer capacity (reader mostly decrypts in-place into caller buffer). /// Default read buffer capacity (reader mostly decrypts in-place into caller buffer).
const DEFAULT_READ_CAPACITY: usize = 16 * 1024; const DEFAULT_READ_CAPACITY: usize = 16 * 1024;
@@ -99,22 +105,6 @@ impl StreamState for CryptoReaderState {
// ============= CryptoReader ============= // ============= CryptoReader =============
/// Reader that decrypts data using AES-CTR with proper state machine. /// Reader that decrypts data using AES-CTR with proper state machine.
///
/// This reader handles partial reads correctly by maintaining internal state
/// and never losing any data that has been read from upstream.
///
/// # State Machine
///
/// ┌──────────┐ read ┌──────────┐
/// │ Idle │ ------------> │ Yielding │
/// │ │ <------------ │ │
/// └──────────┘ drained └──────────┘
/// │ │
/// │ errors │
/// ▼ ▼
/// ┌──────────────────────────────────────┐
/// │ Poisoned │
/// └──────────────────────────────────────┘
pub struct CryptoReader<R> { pub struct CryptoReader<R> {
upstream: R, upstream: R,
decryptor: AesCtr, decryptor: AesCtr,
@@ -315,10 +305,6 @@ impl<R: AsyncRead + Unpin> CryptoReader<R> {
// ============= Pending Ciphertext ============= // ============= Pending Ciphertext =============
/// Pending ciphertext buffer with explicit position and strict max size. /// Pending ciphertext buffer with explicit position and strict max size.
///
/// - append plaintext then encrypt appended range in-place - one-touch copy, no extra Vec
/// - move ciphertext from scratch into pending without copying
/// - explicit compaction behavior for long-lived connections
#[derive(Debug)] #[derive(Debug)]
struct PendingCiphertext { struct PendingCiphertext {
buf: BytesMut, buf: BytesMut,
@@ -361,15 +347,13 @@ impl PendingCiphertext {
} }
// Compact when a large prefix was consumed. // Compact when a large prefix was consumed.
if self.pos >= 32 * 1024 { if self.pos >= 16 * 1024 {
let _ = self.buf.split_to(self.pos); let _ = self.buf.split_to(self.pos);
self.pos = 0; self.pos = 0;
} }
} }
/// Replace the entire pending ciphertext by moving `src` in (swap, no copy). /// Replace the entire pending ciphertext by moving `src` in (swap, no copy).
///
/// Precondition: src.len() <= max_len.
fn replace_with(&mut self, mut src: BytesMut) { fn replace_with(&mut self, mut src: BytesMut) {
debug_assert!(src.len() <= self.max_len); debug_assert!(src.len() <= self.max_len);
@@ -381,12 +365,6 @@ impl PendingCiphertext {
} }
/// Append plaintext and encrypt appended range in-place. /// Append plaintext and encrypt appended range in-place.
///
/// This is the high-throughput buffering path:
/// - copy plaintext into pending buffer
/// - encrypt only the newly appended bytes
///
/// CTR state advances by exactly plaintext.len().
fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> { fn push_encrypted(&mut self, encryptor: &mut AesCtr, plaintext: &[u8]) -> Result<()> {
if plaintext.is_empty() { if plaintext.is_empty() {
return Ok(()); return Ok(());
@@ -444,21 +422,10 @@ impl StreamState for CryptoWriterState {
// ============= CryptoWriter ============= // ============= CryptoWriter =============
/// Writer that encrypts data using AES-CTR with correct async semantics. /// Writer that encrypts data using AES-CTR with correct async semantics.
///
/// - CTR state advances exactly by the number of bytes we report as written
/// - If upstream blocks, ciphertext is buffered/bounded
/// - Backpressure is applied when buffer is full
pub struct CryptoWriter<W> { pub struct CryptoWriter<W> {
upstream: W, upstream: W,
encryptor: AesCtr, encryptor: AesCtr,
state: CryptoWriterState, state: CryptoWriterState,
/// Scratch ciphertext for fast "write-through" path.
///
/// Flow:
/// - encrypt plaintext into scratch
/// - try upstream write
/// - if Pending/partial: move remainder into pending without re-encrypting
scratch: BytesMut, scratch: BytesMut,
} }
@@ -531,9 +498,6 @@ impl<W> CryptoWriter<W> {
} }
/// Select how many plaintext bytes can be accepted in buffering path /// Select how many plaintext bytes can be accepted in buffering path
///
/// Requirement: worst case - upstream pending, must buffer all ciphertext
/// for the accepted bytes
fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize { fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize) -> usize {
if buf_len == 0 { if buf_len == 0 {
return 0; return 0;
@@ -557,11 +521,6 @@ impl<W> CryptoWriter<W> {
impl<W: AsyncWrite + Unpin> CryptoWriter<W> { impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
/// Flush as much pending ciphertext as possible /// Flush as much pending ciphertext as possible
///
/// Returns
/// - Ready(Ok(())) if all pending is flushed or was none
/// - Pending if upstream would block
/// - Ready(Err(_)) on error
fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> { fn poll_flush_pending(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
loop { loop {
match &mut self.state { match &mut self.state {
@@ -606,14 +565,6 @@ impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
pending.advance(n); pending.advance(n);
trace!(
flushed = n,
pending_left = pending.pending_len(),
"CryptoWriter: flushed pending ciphertext"
);
// continue loop to flush more
continue; continue;
} }
} }
@@ -643,9 +594,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
} }
// 1) If we have pending ciphertext, prioritize flushing it // 1) If we have pending ciphertext, prioritize flushing it
// If upstream pending
// -> still accept some plaintext ONLY if we can buffer
// all ciphertext for the accepted portion - bounded
if matches!(this.state, CryptoWriterState::Flushing { .. }) { if matches!(this.state, CryptoWriterState::Flushing { .. }) {
match this.poll_flush_pending(cx) { match this.poll_flush_pending(cx) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
@@ -654,8 +602,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => { Poll::Pending => {
// Upstream blocked. Apply ideal backpressure // Upstream blocked. Apply ideal backpressure
// - accept up to remaining pending capacity
// - if no capacity -> pending
let to_accept = let to_accept =
Self::select_to_accept_for_buffering(&this.state, buf.len()); Self::select_to_accept_for_buffering(&this.state, buf.len());
@@ -670,11 +616,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
let plaintext = &buf[..to_accept]; let plaintext = &buf[..to_accept];
// Disjoint borrows: borrow encryptor and state separately via a match // Disjoint borrows
let encryptor = &mut this.encryptor; let encryptor = &mut this.encryptor;
let pending = Self::ensure_pending(&mut this.state); let pending = Self::ensure_pending(&mut this.state);
// Should not WouldBlock because to_accept <= remaining_capacity
if let Err(e) = pending.push_encrypted(encryptor, plaintext) { if let Err(e) = pending.push_encrypted(encryptor, plaintext) {
if e.kind() == ErrorKind::WouldBlock { if e.kind() == ErrorKind::WouldBlock {
return Poll::Pending; return Poll::Pending;
@@ -682,13 +627,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
return Poll::Ready(Err(e)); return Poll::Ready(Err(e));
} }
trace!(
accepted = to_accept,
pending_len = pending.pending_len(),
pending_cap = pending.remaining_capacity(),
"CryptoWriter: upstream Pending, buffered ciphertext (accepted plaintext)"
);
return Poll::Ready(Ok(to_accept)); return Poll::Ready(Ok(to_accept));
} }
} }
@@ -697,9 +635,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
// 2) Fast path: pending empty -> write-through // 2) Fast path: pending empty -> write-through
debug_assert!(matches!(this.state, CryptoWriterState::Idle)); debug_assert!(matches!(this.state, CryptoWriterState::Idle));
// Worst-case buffering requirement
// - If upstream becomes pending -> buffer full ciphertext for accepted bytes
// -> accept at most MAX_PENDING_WRITE per poll_write call
let to_accept = buf.len().min(MAX_PENDING_WRITE); let to_accept = buf.len().min(MAX_PENDING_WRITE);
let plaintext = &buf[..to_accept]; let plaintext = &buf[..to_accept];
@@ -708,18 +643,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) { match Pin::new(&mut this.upstream).poll_write(cx, &this.scratch) {
Poll::Pending => { Poll::Pending => {
// Upstream blocked: buffer FULL ciphertext for accepted bytes. // Upstream blocked: buffer FULL ciphertext for accepted bytes.
// Move scratch into pending without copying.
let ciphertext = std::mem::take(&mut this.scratch); let ciphertext = std::mem::take(&mut this.scratch);
let pending = Self::ensure_pending(&mut this.state); let pending = Self::ensure_pending(&mut this.state);
pending.replace_with(ciphertext); pending.replace_with(ciphertext);
trace!(
accepted = to_accept,
pending_len = pending.pending_len(),
"CryptoWriter: write-through got Pending, buffered full ciphertext"
);
Poll::Ready(Ok(to_accept)) Poll::Ready(Ok(to_accept))
} }
@@ -736,26 +664,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
Poll::Ready(Ok(n)) => { Poll::Ready(Ok(n)) => {
if n == this.scratch.len() { if n == this.scratch.len() {
trace!(
accepted = to_accept,
ciphertext_len = this.scratch.len(),
"CryptoWriter: write-through wrote full ciphertext directly"
);
this.scratch.clear(); this.scratch.clear();
return Poll::Ready(Ok(to_accept)); return Poll::Ready(Ok(to_accept));
} }
// Partial upstream write of ciphertext: // Partial upstream write of ciphertext
// We accepted `to_accept` plaintext bytes, CTR already advanced for to_accept
// Must buffer the remainder ciphertext
warn!(
accepted = to_accept,
ciphertext_len = this.scratch.len(),
written_ciphertext = n,
"CryptoWriter: partial upstream write, buffering remainder"
);
// Split off remainder without copying
let remainder = this.scratch.split_off(n); let remainder = this.scratch.split_off(n);
this.scratch.clear(); this.scratch.clear();
@@ -788,7 +701,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
let this = self.get_mut(); let this = self.get_mut();
// Best-effort flush pending ciphertext before shutdown // Best-effort flush pending ciphertext before shutdown
// If upstream blocks, proceed to shutdown anyway
match this.poll_flush_pending(cx) { match this.poll_flush_pending(cx) {
Poll::Pending => { Poll::Pending => {
debug!( debug!(
@@ -807,9 +719,6 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
// ============= PassthroughStream ============= // ============= PassthroughStream =============
/// Passthrough stream for fast mode - no encryption/decryption /// Passthrough stream for fast mode - no encryption/decryption
///
/// Used when keys are set up so that client and Telegram use the same
/// encryption, allowing data to pass through without re-encryption
pub struct PassthroughStream<S> { pub struct PassthroughStream<S> {
inner: S, inner: S,
} }

View File

@@ -5,8 +5,10 @@
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use std::io::Result; use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
// ============= Frame Types ============= // ============= Frame Types =============
@@ -147,11 +149,11 @@ pub trait FrameCodec: Send + Sync {
// ============= Codec Factory ============= // ============= Codec Factory =============
/// Create a frame codec for the given protocol tag /// Create a frame codec for the given protocol tag
pub fn create_codec(proto_tag: ProtoTag) -> Box<dyn FrameCodec> { pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn FrameCodec> {
match proto_tag { match proto_tag {
ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()), ProtoTag::Abridged => Box::new(crate::stream::frame_codec::AbridgedCodec::new()),
ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()), ProtoTag::Intermediate => Box::new(crate::stream::frame_codec::IntermediateCodec::new()),
ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new()), ProtoTag::Secure => Box::new(crate::stream::frame_codec::SecureCodec::new(rng)),
} }
} }

View File

@@ -5,9 +5,11 @@
use bytes::{Bytes, BytesMut, BufMut}; use bytes::{Bytes, BytesMut, BufMut};
use std::io::{self, Error, ErrorKind}; use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec ============= // ============= Unified Codec =============
@@ -21,14 +23,17 @@ pub struct FrameCodec {
proto_tag: ProtoTag, proto_tag: ProtoTag,
/// Maximum allowed frame size /// Maximum allowed frame size
max_frame_size: usize, max_frame_size: usize,
/// RNG for secure padding
rng: Arc<SecureRandom>,
} }
impl FrameCodec { impl FrameCodec {
/// Create a new codec for the given protocol /// Create a new codec for the given protocol
pub fn new(proto_tag: ProtoTag) -> Self { pub fn new(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
Self { Self {
proto_tag, proto_tag,
max_frame_size: 16 * 1024 * 1024, // 16MB default max_frame_size: 16 * 1024 * 1024, // 16MB default
rng,
} }
} }
@@ -64,7 +69,7 @@ impl Encoder<Frame> for FrameCodec {
match self.proto_tag { match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst), ProtoTag::Abridged => encode_abridged(&frame, dst),
ProtoTag::Intermediate => encode_intermediate(&frame, dst), ProtoTag::Intermediate => encode_intermediate(&frame, dst),
ProtoTag::Secure => encode_secure(&frame, dst), ProtoTag::Secure => encode_secure(&frame, dst, &self.rng),
} }
} }
} }
@@ -288,9 +293,7 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
Ok(Some(Frame::with_meta(data, meta))) Ok(Some(Frame::with_meta(data, meta)))
} }
fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> { fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
use crate::crypto::random::SECURE_RANDOM;
let data = &frame.data; let data = &frame.data;
// Simple ACK: just send data // Simple ACK: just send data
@@ -303,10 +306,10 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
// Generate padding to make length not divisible by 4 // Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 { let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned // Add 1-3 bytes to make it non-aligned
(SECURE_RANDOM.range(3) + 1) as usize (rng.range(3) + 1) as usize
} else { } else {
// Already non-aligned, can add 0-3 // Already non-aligned, can add 0-3
SECURE_RANDOM.range(4) as usize rng.range(4) as usize
}; };
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
@@ -321,7 +324,7 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(data); dst.extend_from_slice(data);
if padding_len > 0 { if padding_len > 0 {
let padding = SECURE_RANDOM.bytes(padding_len); let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding); dst.extend_from_slice(&padding);
} }
@@ -445,19 +448,21 @@ impl FrameCodecTrait for IntermediateCodec {
/// Secure Intermediate protocol codec /// Secure Intermediate protocol codec
pub struct SecureCodec { pub struct SecureCodec {
max_frame_size: usize, max_frame_size: usize,
rng: Arc<SecureRandom>,
} }
impl SecureCodec { impl SecureCodec {
pub fn new() -> Self { pub fn new(rng: Arc<SecureRandom>) -> Self {
Self { Self {
max_frame_size: 16 * 1024 * 1024, max_frame_size: 16 * 1024 * 1024,
rng,
} }
} }
} }
impl Default for SecureCodec { impl Default for SecureCodec {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new(Arc::new(SecureRandom::new()))
} }
} }
@@ -474,7 +479,7 @@ impl Encoder<Frame> for SecureCodec {
type Error = io::Error; type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst) encode_secure(&frame, dst, &self.rng)
} }
} }
@@ -485,7 +490,7 @@ impl FrameCodecTrait for SecureCodec {
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> { fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len(); let before = dst.len();
encode_secure(frame, dst)?; encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before) Ok(dst.len() - before)
} }
@@ -506,6 +511,8 @@ mod tests {
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex; use tokio::io::duplex;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use std::sync::Arc;
#[tokio::test] #[tokio::test]
async fn test_framed_abridged() { async fn test_framed_abridged() {
@@ -541,8 +548,8 @@ mod tests {
async fn test_framed_secure() { async fn test_framed_secure() {
let (client, server) = duplex(4096); let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new()); let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new()); let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone()); let frame = Frame::new(original.clone());
@@ -557,8 +564,8 @@ mod tests {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] { for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096); let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag)); let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag)); let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
// Use 4-byte aligned data for abridged compatibility // Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]); let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
@@ -607,7 +614,7 @@ mod tests {
#[test] #[test]
fn test_frame_too_large() { fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate) let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100); .with_max_frame_size(100);
// Create a "frame" that claims to be very large // Create a "frame" that claims to be very large

View File

@@ -4,8 +4,8 @@ use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result}; use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::crypto::crc32; use crate::crypto::{crc32, SecureRandom};
use crate::crypto::random::SECURE_RANDOM; use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream}; use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame ============= // ============= Abridged (Compact) Frame =============
@@ -251,11 +251,12 @@ impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
/// Writer for secure intermediate MTProto framing /// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> { pub struct SecureIntermediateFrameWriter<W> {
upstream: W, upstream: W,
rng: Arc<SecureRandom>,
} }
impl<W> SecureIntermediateFrameWriter<W> { impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self { pub fn new(upstream: W, rng: Arc<SecureRandom>) -> Self {
Self { upstream } Self { upstream, rng }
} }
} }
@@ -267,8 +268,8 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
} }
// Add random padding (0-3 bytes) // Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4); let padding_len = self.rng.range(4);
let padding = SECURE_RANDOM.bytes(padding_len); let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes(); let len_bytes = (total_len as u32).to_le_bytes();
@@ -454,11 +455,11 @@ pub enum FrameWriterKind<W> {
} }
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> { impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self { pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag { match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)), ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)), ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)), ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
} }
} }
@@ -483,6 +484,8 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
mod tests { mod tests {
use super::*; use super::*;
use tokio::io::duplex; use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
#[tokio::test] #[tokio::test]
async fn test_abridged_roundtrip() { async fn test_abridged_roundtrip() {
@@ -539,7 +542,7 @@ mod tests {
async fn test_secure_intermediate_padding() { async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client); let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server); let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
@@ -572,7 +575,7 @@ mod tests {
async fn test_frame_reader_kind() { async fn test_frame_reader_kind() {
let (client, server) = duplex(1024); let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate); let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate); let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4]; let data = vec![1u8, 2, 3, 4];

View File

@@ -10,4 +10,4 @@ pub use pool::ConnectionPool;
pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol}; pub use proxy_protocol::{ProxyProtocolInfo, parse_proxy_protocol};
pub use socket::*; pub use socket::*;
pub use socks::*; pub use socks::*;
pub use upstream::UpstreamManager; pub use upstream::{UpstreamManager, StartupPingResult, DcPingResult};

View File

@@ -1,26 +1,128 @@
//! Upstream Management //! Upstream Management with per-DC latency-weighted selection
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio::time::Instant;
use rand::Rng; use rand::Rng;
use tracing::{debug, warn, error, info}; use tracing::{debug, warn, info, trace};
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::error::{Result, ProxyError}; use crate::error::{Result, ProxyError};
use crate::protocol::constants::{TG_DATACENTERS_V4, TG_DATACENTERS_V6, TG_DATACENTER_PORT};
use crate::transport::socket::create_outgoing_socket_bound; use crate::transport::socket::create_outgoing_socket_bound;
use crate::transport::socks::{connect_socks4, connect_socks5}; use crate::transport::socks::{connect_socks4, connect_socks5};
/// Number of Telegram datacenters
const NUM_DCS: usize = 5;
// ============= RTT Tracking =============
#[derive(Debug, Clone, Copy)]
struct LatencyEma {
value_ms: Option<f64>,
alpha: f64,
}
impl LatencyEma {
const fn new(alpha: f64) -> Self {
Self { value_ms: None, alpha }
}
fn update(&mut self, sample_ms: f64) {
self.value_ms = Some(match self.value_ms {
None => sample_ms,
Some(prev) => prev * (1.0 - self.alpha) + sample_ms * self.alpha,
});
}
fn get(&self) -> Option<f64> {
self.value_ms
}
}
// ============= Upstream State =============
#[derive(Debug)] #[derive(Debug)]
struct UpstreamState { struct UpstreamState {
config: UpstreamConfig, config: UpstreamConfig,
healthy: bool, healthy: bool,
fails: u32, fails: u32,
last_check: std::time::Instant, last_check: std::time::Instant,
/// Per-DC latency EMA (index 0 = DC1, index 4 = DC5)
dc_latency: [LatencyEma; NUM_DCS],
} }
impl UpstreamState {
fn new(config: UpstreamConfig) -> Self {
Self {
config,
healthy: true,
fails: 0,
last_check: std::time::Instant::now(),
dc_latency: [LatencyEma::new(0.3); NUM_DCS],
}
}
/// Map DC index to latency array slot (0..NUM_DCS).
///
/// Matches the C implementation's `mf_cluster_lookup` behavior:
/// - Standard DCs ±1..±5 → direct mapping to array index 0..4
/// - Unknown DCs (CDN, media, etc.) → default DC slot (index 1 = DC 2)
/// This matches Telegram's `default 2;` in proxy-multi.conf.
/// - There is NO modular arithmetic in the C implementation.
fn dc_array_idx(dc_idx: i16) -> Option<usize> {
let abs_dc = dc_idx.unsigned_abs() as usize;
if abs_dc == 0 {
return None;
}
if abs_dc >= 1 && abs_dc <= NUM_DCS {
Some(abs_dc - 1)
} else {
// Unknown DC → default cluster (DC 2, index 1)
// Same as C: mf_cluster_lookup returns default_cluster
Some(1)
}
}
/// Get latency for a specific DC, falling back to average across all known DCs
fn effective_latency(&self, dc_idx: Option<i16>) -> Option<f64> {
// Try DC-specific latency first
if let Some(di) = dc_idx.and_then(Self::dc_array_idx) {
if let Some(ms) = self.dc_latency[di].get() {
return Some(ms);
}
}
// Fallback: average of all known DC latencies
let (sum, count) = self.dc_latency.iter()
.filter_map(|l| l.get())
.fold((0.0, 0u32), |(s, c), v| (s + v, c + 1));
if count > 0 { Some(sum / count as f64) } else { None }
}
}
/// Result of a single DC ping
#[derive(Debug, Clone)]
pub struct DcPingResult {
pub dc_idx: usize,
pub dc_addr: SocketAddr,
pub rtt_ms: Option<f64>,
pub error: Option<String>,
}
/// Result of startup ping for one upstream
#[derive(Debug, Clone)]
pub struct StartupPingResult {
pub results: Vec<DcPingResult>,
pub upstream_name: String,
}
// ============= Upstream Manager =============
#[derive(Clone)] #[derive(Clone)]
pub struct UpstreamManager { pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>, upstreams: Arc<RwLock<Vec<UpstreamState>>>,
@@ -30,12 +132,7 @@ impl UpstreamManager {
pub fn new(configs: Vec<UpstreamConfig>) -> Self { pub fn new(configs: Vec<UpstreamConfig>) -> Self {
let states = configs.into_iter() let states = configs.into_iter()
.filter(|c| c.enabled) .filter(|c| c.enabled)
.map(|c| UpstreamState { .map(UpstreamState::new)
config: c,
healthy: true, // Optimistic start
fails: 0,
last_check: std::time::Instant::now(),
})
.collect(); .collect();
Self { Self {
@@ -43,48 +140,78 @@ impl UpstreamManager {
} }
} }
/// Select an upstream using Weighted Round Robin (simplified) /// Select upstream using latency-weighted random selection.
async fn select_upstream(&self) -> Option<usize> { ///
/// `effective_weight = config_weight × latency_factor`
///
/// where `latency_factor = 1000 / latency_ms` if latency is known,
/// or `1.0` if no latency data is available.
///
/// This means a 50ms upstream gets factor 20, a 200ms upstream gets
/// factor 5 — the faster route is 4× more likely to be chosen
/// (all else being equal).
async fn select_upstream(&self, dc_idx: Option<i16>) -> Option<usize> {
let upstreams = self.upstreams.read().await; let upstreams = self.upstreams.read().await;
if upstreams.is_empty() { if upstreams.is_empty() {
return None; return None;
} }
let healthy_indices: Vec<usize> = upstreams.iter() let healthy: Vec<usize> = upstreams.iter()
.enumerate() .enumerate()
.filter(|(_, u)| u.healthy) .filter(|(_, u)| u.healthy)
.map(|(i, _)| i) .map(|(i, _)| i)
.collect(); .collect();
if healthy_indices.is_empty() { if healthy.is_empty() {
// If all unhealthy, try any random one // All unhealthy — pick any
return Some(rand::thread_rng().gen_range(0..upstreams.len())); return Some(rand::rng().gen_range(0..upstreams.len()));
} }
// Weighted selection if healthy.len() == 1 {
let total_weight: u32 = healthy_indices.iter() return Some(healthy[0]);
.map(|&i| upstreams[i].config.weight as u32)
.sum();
if total_weight == 0 {
return Some(healthy_indices[rand::thread_rng().gen_range(0..healthy_indices.len())]);
} }
let mut choice = rand::thread_rng().gen_range(0..total_weight); // Calculate latency-weighted scores
let weights: Vec<(usize, f64)> = healthy.iter().map(|&i| {
let base = upstreams[i].config.weight as f64;
let latency_factor = upstreams[i].effective_latency(dc_idx)
.map(|ms| if ms > 1.0 { 1000.0 / ms } else { 1000.0 })
.unwrap_or(1.0);
for &idx in &healthy_indices { (i, base * latency_factor)
let weight = upstreams[idx].config.weight as u32; }).collect();
let total: f64 = weights.iter().map(|(_, w)| w).sum();
if total <= 0.0 {
return Some(healthy[rand::rng().gen_range(0..healthy.len())]);
}
let mut choice: f64 = rand::rng().gen_range(0.0..total);
for &(idx, weight) in &weights {
if choice < weight { if choice < weight {
trace!(
upstream = idx,
dc = ?dc_idx,
weight = format!("{:.2}", weight),
total = format!("{:.2}", total),
"Upstream selected"
);
return Some(idx); return Some(idx);
} }
choice -= weight; choice -= weight;
} }
Some(healthy_indices[0]) Some(healthy[0])
} }
pub async fn connect(&self, target: SocketAddr) -> Result<TcpStream> { /// Connect to target through a selected upstream.
let idx = self.select_upstream().await ///
/// `dc_idx` is used for latency-based upstream selection and RTT tracking.
/// Pass `None` if DC index is unknown.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>) -> Result<TcpStream> {
let idx = self.select_upstream(dc_idx).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
let upstream = { let upstream = {
@@ -92,28 +219,34 @@ impl UpstreamManager {
guard[idx].config.clone() guard[idx].config.clone()
}; };
let start = Instant::now();
match self.connect_via_upstream(&upstream, target).await { match self.connect_via_upstream(&upstream, target).await {
Ok(stream) => { Ok(stream) => {
// Mark success let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) { if let Some(u) = guard.get_mut(idx) {
if !u.healthy { if !u.healthy {
debug!("Upstream recovered: {:?}", u.config); debug!(rtt_ms = format!("{:.1}", rtt_ms), "Upstream recovered");
} }
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
// Store per-DC latency
if let Some(di) = dc_idx.and_then(UpstreamState::dc_array_idx) {
u.dc_latency[di].update(rtt_ms);
}
} }
Ok(stream) Ok(stream)
}, },
Err(e) => { Err(e) => {
// Mark failure
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) { if let Some(u) = guard.get_mut(idx) {
u.fails += 1; u.fails += 1;
warn!("Failed to connect via upstream {:?}: {}. Fails: {}", u.config, e, u.fails); warn!(fails = u.fails, "Upstream failed: {}", e);
if u.fails > 3 { if u.fails > 3 {
u.healthy = false; u.healthy = false;
warn!("Upstream disabled due to failures: {:?}", u.config); warn!("Upstream marked unhealthy");
} }
} }
Err(e) Err(e)
@@ -129,18 +262,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(target, bind_ip)?; let socket = create_outgoing_socket_bound(target, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&target.into()) { match socket.connect(&target.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -149,8 +280,6 @@ impl UpstreamManager {
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks4 { address, interface, user_id } => { UpstreamType::Socks4 { address, interface, user_id } => {
info!("Connecting to target {} via SOCKS4 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse() let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?; .map_err(|_| ProxyError::Config("Invalid SOCKS4 address".to_string()))?;
@@ -159,18 +288,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) { match socket.connect(&proxy_addr.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?; let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -180,8 +307,6 @@ impl UpstreamManager {
Ok(stream) Ok(stream)
}, },
UpstreamType::Socks5 { address, interface, username, password } => { UpstreamType::Socks5 { address, interface, username, password } => {
info!("Connecting to target {} via SOCKS5 proxy {}", target, address);
let proxy_addr: SocketAddr = address.parse() let proxy_addr: SocketAddr = address.parse()
.map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?; .map_err(|_| ProxyError::Config("Invalid SOCKS5 address".to_string()))?;
@@ -190,18 +315,16 @@ impl UpstreamManager {
let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?; let socket = create_outgoing_socket_bound(proxy_addr, bind_ip)?;
// Non-blocking connect logic
socket.set_nonblocking(true)?; socket.set_nonblocking(true)?;
match socket.connect(&proxy_addr.into()) { match socket.connect(&proxy_addr.into()) {
Ok(()) => {}, Ok(()) => {},
Err(err) if err.raw_os_error() == Some(115) || err.kind() == std::io::ErrorKind::WouldBlock => {}, Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) || err.kind() == std::io::ErrorKind::WouldBlock => {},
Err(err) => return Err(ProxyError::Io(err)), Err(err) => return Err(ProxyError::Io(err)),
} }
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let mut stream = TcpStream::from_std(std_stream)?; let mut stream = TcpStream::from_std(std_stream)?;
// Wait for connection to complete
stream.writable().await?; stream.writable().await?;
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
@@ -213,13 +336,100 @@ impl UpstreamManager {
} }
} }
/// Background task to check health // ============= Startup Ping =============
pub async fn run_health_checks(&self) {
// Simple TCP connect check to a known stable DC (e.g. 149.154.167.50:443 - DC2) /// Ping all Telegram DCs through all upstreams.
let check_target: SocketAddr = "149.154.167.50:443".parse().unwrap(); pub async fn ping_all_dcs(&self, prefer_ipv6: bool) -> Vec<StartupPingResult> {
let upstreams: Vec<(usize, UpstreamConfig)> = {
let guard = self.upstreams.read().await;
guard.iter().enumerate()
.map(|(i, u)| (i, u.config.clone()))
.collect()
};
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut all_results = Vec::new();
for (upstream_idx, upstream_config) in &upstreams {
let upstream_name = match &upstream_config.upstream_type {
UpstreamType::Direct { interface } => {
format!("direct{}", interface.as_ref().map(|i| format!(" ({})", i)).unwrap_or_default())
}
UpstreamType::Socks4 { address, .. } => format!("socks4://{}", address),
UpstreamType::Socks5 { address, .. } => format!("socks5://{}", address),
};
let mut dc_results = Vec::new();
for (dc_zero_idx, dc_ip) in datacenters.iter().enumerate() {
let dc_addr = SocketAddr::new(*dc_ip, TG_DATACENTER_PORT);
let ping_result = tokio::time::timeout(
Duration::from_secs(5),
self.ping_single_dc(upstream_config, dc_addr)
).await;
let result = match ping_result {
Ok(Ok(rtt_ms)) => {
// Store per-DC latency
let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(*upstream_idx) {
u.dc_latency[dc_zero_idx].update(rtt_ms);
}
DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: Some(rtt_ms),
error: None,
}
}
Ok(Err(e)) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some(e.to_string()),
},
Err(_) => DcPingResult {
dc_idx: dc_zero_idx + 1,
dc_addr,
rtt_ms: None,
error: Some("timeout (5s)".to_string()),
},
};
dc_results.push(result);
}
all_results.push(StartupPingResult {
results: dc_results,
upstream_name,
});
}
all_results
}
async fn ping_single_dc(&self, config: &UpstreamConfig, target: SocketAddr) -> Result<f64> {
let start = Instant::now();
let _stream = self.connect_via_upstream(config, target).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0)
}
// ============= Health Checks =============
/// Background health check: rotates through DCs, 30s interval.
pub async fn run_health_checks(&self, prefer_ipv6: bool) {
let datacenters = if prefer_ipv6 { &*TG_DATACENTERS_V6 } else { &*TG_DATACENTERS_V4 };
let mut dc_rotation = 0usize;
loop { loop {
tokio::time::sleep(Duration::from_secs(60)).await; tokio::time::sleep(Duration::from_secs(30)).await;
let dc_zero_idx = dc_rotation % datacenters.len();
dc_rotation += 1;
let check_target = SocketAddr::new(datacenters[dc_zero_idx], TG_DATACENTER_PORT);
let count = self.upstreams.read().await.len(); let count = self.upstreams.read().await.len();
for i in 0..count { for i in 0..count {
@@ -228,6 +438,7 @@ impl UpstreamManager {
guard[i].config.clone() guard[i].config.clone()
}; };
let start = Instant::now();
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(10), Duration::from_secs(10),
self.connect_via_upstream(&config, check_target) self.connect_via_upstream(&config, check_target)
@@ -238,18 +449,36 @@ impl UpstreamManager {
match result { match result {
Ok(Ok(_stream)) => { Ok(Ok(_stream)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
u.dc_latency[dc_zero_idx].update(rtt_ms);
if !u.healthy { if !u.healthy {
debug!("Upstream recovered: {:?}", u.config); info!(
rtt = format!("{:.0}ms", rtt_ms),
dc = dc_zero_idx + 1,
"Upstream recovered"
);
} }
u.healthy = true; u.healthy = true;
u.fails = 0; u.fails = 0;
} }
Ok(Err(e)) => { Ok(Err(e)) => {
debug!("Health check failed for {:?}: {}", u.config, e); u.fails += 1;
// Don't mark unhealthy immediately in background check debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check failed: {}", e);
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (fails)");
}
} }
Err(_) => { Err(_) => {
debug!("Health check timeout for {:?}", u.config); u.fails += 1;
debug!(dc = dc_zero_idx + 1, fails = u.fails,
"Health check timeout");
if u.fails > 3 {
u.healthy = false;
warn!("Upstream unhealthy (timeout)");
}
} }
} }
u.last_check = std::time::Instant::now(); u.last_check = std::time::Instant::now();